This code adds additional outcomes to the study; it can be appended to the end of study.notebook.

In [None]:
# =============================================================================
# CELL 14: UNIFIED AIPW MULTI-OUTCOME ENGINE
# =============================================================================
# Purpose: for every secondary endpoint, derive a binary outcome vector from an
# event-date column and run the cross-fitted AIPW estimator on it. Results are
# collected in `all_outcome_data` for the cells that follow.
# Uses names from earlier cells: df_final, X_final, T_final, run_cross_fitted_aipw.
print("\n--- CELL 14: Running AIPW for Secondary Endpoints ---")

# 1. Standardize Endpoint Definitions
# endpoint name -> (event-date column in df_final, follow-up window in days)
ADDITIONAL_ENDPOINTS = {
    'AKI_7':       ('date_AKI_30', 7),
    'MORTALITY_30': ('date_DEATH', 30),
    'THYROID_90':   ('date_THYROID_90', 90)
}

# Tunables for this cell (previously inline magic numbers).
MIN_EVENTS = 10   # endpoints with fewer events than this are skipped
N_FOLDS = 5       # number of cross-fitting folds passed to the AIPW engine

# Dictionary to store full results for plotting/policy analysis
all_outcome_data = {}
skipped_endpoints = []

for name, (date_col, window) in ADDITIONAL_ENDPOINTS.items():
    print(f"\n>>> PROCESSING OUTCOME: {name} ({window} day window) <<<")

    # Define Binary Y: 1 when the event date is at most `window` days after the
    # index date. Rows with a missing event date compare as False and become 0.
    days_to_event = (df_final[date_col] - df_final['index_date']).dt.days
    Y_vec_new = (days_to_event <= window).fillna(False).astype(int).values

    # Audit event counts
    n_events = Y_vec_new.sum()
    print(f"  Event Count: {n_events} ({100*n_events/len(df_final):.2f}%)")

    # Audit only (Y is not altered): the `<= window` test has no lower bound, so
    # any event dated before the index date is counted as an outcome.
    # NOTE(review): confirm whether pre-index events should be excluded upstream.
    n_pre_index = int((days_to_event < 0).sum())
    if n_pre_index:
        print(f"  WARNING: {n_pre_index} events are dated before index_date and are counted as outcomes.")

    if n_events < MIN_EVENTS:
        print(f"  SKIPPING {name}: Insufficient events for robust SuperLearner.")
        skipped_endpoints.append(name)
        continue

    # Run the Gold Standard Engine (defined in Cell 9)
    # This uses the same X_final and T_final used for AKI_30.
    # The result is bound to loop-local names on purpose: Cell 15 reads the
    # global `stats` as the AKI_30 result, so assigning to `stats` here would
    # make the AKI_30 row report the last secondary endpoint's numbers.
    endpoint_stats, endpoint_preds = run_cross_fitted_aipw(
        X_final, T_final, Y_vec_new, n_folds=N_FOLDS
    )

    # Store for next cells
    all_outcome_data[name] = {
        'stats': endpoint_stats,
        'preds': endpoint_preds,
        'Y': Y_vec_new
    }

print("\n--- AIPW Inference Complete for all Endpoints ---")
if skipped_endpoints:
    # Downstream cells only see endpoints present in `all_outcome_data`.
    print(f"  Skipped endpoints (no results stored): {skipped_endpoints}")

In [None]:
# =============================================================================
# CELL 15: CONSOLIDATED INFERENCE & EMPIRICAL CALIBRATION
# =============================================================================
# Purpose: gather the per-outcome AIPW summaries into one table, derive the
# ratio-scale columns expected by `calibrate_estimates`, calibrate, and draw a
# forest plot of the risk differences.
# Uses names from earlier cells: all_outcome_data, calibrate_estimates, and
# optionally the globals `stats` / `outcome_name` from the primary AKI_30 run.
print("\n--- CELL 15: Consolidated Forest Plot & Calibration ---")

summary_rows = []

# Include AKI_30 from your previous run if available.
# Guard: if the global `stats` is the very same object as one of the secondary
# endpoint results, it was overwritten after the AKI_30 run and must not be
# reported under the AKI_30 label.
if 'stats' in globals() and 'outcome_name' in globals() and outcome_name == 'AKI_30':
    if any(stats is entry['stats'] for entry in all_outcome_data.values()):
        print("  WARNING: global `stats` was overwritten by a secondary endpoint; "
              "AKI_30 row omitted. Re-run the AKI_30 cell to include it.")
    else:
        summary_rows.append({'Outcome': 'AKI_30', **stats})

# Add the new ones
for name, data in all_outcome_data.items():
    summary_rows.append({'Outcome': name, **data['stats']})

if not summary_rows:
    # An empty frame has no 'RR'/'SE'/'ATE' columns; fail with a clear message
    # instead of a KeyError further down.
    raise RuntimeError("No outcome results available: run the AIPW cells before Cell 15.")

df_results_all = pd.DataFrame(summary_rows)

# Apply Empirical Calibration (Using the logic from your Cell 2.5)
# We treat the AIPW estimate as the "Observed" effect. Calibration works on the
# log-ratio scale, so the risk-difference SE has to be mapped to log(RR).
#
# SE / |ATE| is just 1/|z| and is not a standard error of log(RR); it can yield a
# log-RR interval that contradicts the ATE interval. The test-based approximation
# keeps the same z-statistic on both scales:
#     SE[log RR] = |log RR| / |z| = |log RR| * SE / |ATE|
# When ATE is exactly 0 that ratio is 0/0, so fall back to SE / Risk_0, which is
# the limit of the expression as the effect approaches the null.
# NOTE(review): an influence-function SE for log(RR) would be preferable if the
# engine can expose it.
log_rr = np.log(df_results_all['RR'])
abs_ate = df_results_all['ATE'].abs()
se_log_rr = (log_rr.abs() * df_results_all['SE'] / abs_ate).where(
    abs_ate > 0, df_results_all['SE'] / df_results_all['Risk_0']
)

df_results_all['HR_Cox'] = df_results_all['RR'] # Proxying RR as HR for the calibration function
df_results_all['HR_CI_Low'] = np.exp(log_rr - 1.96 * se_log_rr)
df_results_all['HR_CI_High'] = np.exp(log_rr + 1.96 * se_log_rr)

# Calibrate!
df_calibrated = calibrate_estimates(df_results_all)

# Forest Plot of Risk Differences (ATE)
fig, ax = plt.subplots(figsize=(10, 6))
y_pos = np.arange(len(df_calibrated))
ax.errorbar(df_calibrated['ATE'], y_pos,
            xerr=[df_calibrated['ATE'] - df_calibrated['CI_Lower'],
                  df_calibrated['CI_Upper'] - df_calibrated['ATE']],
            fmt='o', color='black', capsize=5, label='95% CI (AIPW)')

ax.axvline(0, color='red', linestyle='--')  # null effect reference line
ax.set_yticks(y_pos)
ax.set_yticklabels(df_calibrated['Outcome'])
ax.set_xlabel("Risk Difference (ATE)")
ax.set_title("Consolidated Causal Effects (AIPW + Cross-Fitting)")
ax.grid(True, alpha=0.3)
ax.legend()  # the errorbar label was previously never rendered
plt.show()

display(df_calibrated[['Outcome', 'Risk_1', 'Risk_0', 'ATE', 'RR', 'P_Value', 'P_Calibrated']].round(4))

In [None]:
# =============================================================================
# CELL 16: MULTI-OUTCOME POLICY FRONTIER (THE TRADE-OFF MAP)
# =============================================================================
# Purpose: for every stored outcome, estimate the doubly-robust risk under each
# candidate treatment policy, plot risk against the share of patients withheld
# contrast, and compare per-patient effect estimates for AKI-7 vs Mortality-30.
# Uses names from earlier cells: all_outcome_data, policies, df_final, T_vec.
print("\n--- CELL 16: Multi-Outcome Policy Frontier ---")

# We want to see how the eGFR Rule < 30 and < 45 impacts ALL outcomes
policy_comparison = []

# Degenerate policies that are left out of the comparison.
EXCLUDED_POLICIES = {'Always Contrast (100%)', 'Never Contrast (0%)'}

# A policy's assignment vector depends only on df_final, not on the outcome, so
# it is computed once here and reused for every outcome below.
policy_assignments = {
    pol_name: np.asarray(func(df_final), dtype=float)
    for pol_name, func in policies.items()
    if pol_name not in EXCLUDED_POLICIES
}

for name, data in all_outcome_data.items():
    mu1 = data['preds']['mu1']
    mu0 = data['preds']['mu0']
    pi  = data['preds']['pi']
    Y   = data['Y']

    # Calculate DR Influence Functions
    # NOTE(review): Cell 14 fits the models with `T_final`; confirm `T_vec` is the
    # same treatment vector in the same row order.
    gamma_1 = mu1 + (T_vec / pi) * (Y - mu1)
    gamma_0 = mu0 + ((1 - T_vec) / (1 - pi)) * (Y - mu0)

    for pol_name, d_vec in policy_assignments.items():
        # Per-patient DR score under the policy: treated arm where d=1, control where d=0.
        psi_i = d_vec * gamma_1 + (1 - d_vec) * gamma_0

        policy_comparison.append({
            'Outcome': name,
            'Policy': pol_name,
            'Risk': np.mean(psi_i),
            'Risk_SE': np.std(psi_i) / np.sqrt(len(psi_i)),
            'Withholding': np.mean(1 - d_vec)
        })

df_policy_map = pd.DataFrame(policy_comparison)

# Plotting the trade-offs
if df_policy_map.empty:
    # No stored outcomes or no eligible policies: nothing to draw.
    print("  No policy/outcome combinations available; skipping trade-off plot.")
else:
    fig, ax1 = plt.subplots(figsize=(12, 7))

    # Scales differ across outcomes, so each outcome's risk is divided by its
    # mean risk across the compared policies (1.0 = that outcome's average).
    for outcome in df_policy_map['Outcome'].unique():
        # Sort by withholding rate so the connecting line runs left to right
        # instead of following the insertion order of `policies`.
        subset = df_policy_map[df_policy_map['Outcome'] == outcome].sort_values('Withholding')
        base_risk = subset['Risk'].mean()
        ax1.errorbar(subset['Withholding'], (subset['Risk'] / base_risk),
                     yerr=(1.96 * subset['Risk_SE'] / base_risk),
                     fmt='-o', label=outcome, alpha=0.7)

    ax1.set_title("Policy Impact Across Multiple Endpoints\n(Normalized Risk vs. Withholding Rate)")
    ax1.set_xlabel("Proportion of Patients Withheld Contrast")
    ax1.set_ylabel("Relative Change in Estimated Risk")
    ax1.legend()
    ax1.grid(True, linestyle=':')
    plt.show()

# Final ITE Integration: Who benefits most for Mortality vs AKI?
print("--- Correlation of Individual Treatment Effects (ITE) ---")
# If ITE for AKI and ITE for Mortality are correlated, decisions are easy.
# If not, we have a 'Preference Sensitive' clinical zone.
# Cell 14 may skip an endpoint with too few events, in which case its key is
# absent; check before indexing to avoid a KeyError.
ite_endpoints = ('AKI_7', 'MORTALITY_30')
missing_endpoints = [key for key in ite_endpoints if key not in all_outcome_data]

if missing_endpoints:
    print(f"  SKIPPING ITE map: no AIPW results stored for {missing_endpoints}.")
else:
    # mu0 - mu1: predicted risk under control minus predicted risk under treatment.
    ite_aki = all_outcome_data['AKI_7']['preds']['mu0'] - all_outcome_data['AKI_7']['preds']['mu1']
    ite_mort = all_outcome_data['MORTALITY_30']['preds']['mu0'] - all_outcome_data['MORTALITY_30']['preds']['mu1']

    fig_ite, ax_ite = plt.subplots(figsize=(8, 6))
    density = ax_ite.hexbin(ite_aki, ite_mort, gridsize=30, cmap='YlOrRd')
    ax_ite.axhline(0, color='black', lw=1)
    ax_ite.axvline(0, color='black', lw=1)
    ax_ite.set_xlabel("ITE AKI-7 (Benefit of Contrast)")
    ax_ite.set_ylabel("ITE Mortality-30 (Benefit of Contrast)")
    ax_ite.set_title("Clinical Equipoise Map: AKI vs Mortality Benefit")
    fig_ite.colorbar(density, ax=ax_ite, label='Patient Density')
    plt.show()