# 05b Rebuild Top Groups and Refit Models

This notebook rebuilds a model for the top groups selected in `03a_GI_validate_CATE_estimators.ipynb` / `04a_GI_rank_CATE_estimators.ipynb`.

- Identify units that fall in the top quantile [0.9, 1] across 12 CV folds with frequency >= 0.5.
- Compute Neyman t-statistic and p-value for that subgroup.
- Retune estimators with 4-fold CV and refit on all units.


In [14]:
from pathlib import Path

# Create the output tree up front; directories that already exist are left alone
output_root = Path("output")
standard_dirs = (
    output_root,
    output_root / "analysis",
    output_root / "params",
    output_root / "figures",
    output_root / "tables",
)
for out_dir in standard_dirs:
    out_dir.mkdir(parents=True, exist_ok=True)


In [15]:
# Setup and imports
import os
import pickle
import warnings
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from scipy.stats import norm

from methods.cate_estimator_validation import (
    make_top_ensemble,
)
from methods.causal_functions import get_subgroup_t_statistic, get_Neyman_ATE, get_subgroup_CATE_std

# Config
outcome = "fausebal"  # adjust as needed
DATA_DIR = Path("output/analysis")
PARAMS_DIR = Path("output/params")
ANALYSIS_DIR = DATA_DIR / outcome
PARAMS_PATH = PARAMS_DIR / outcome / f"{outcome}_tuned_params.pkl"
IMPUTATION_META = PARAMS_DIR / outcome / "analysis_imputation_meta.pkl"

assert IMPUTATION_META.exists(), "Missing imputation metadata from 02. Run 02_tune_CATE_estimators.ipynb first."
# NOTE: pickle/joblib loading executes code embedded in the file; only load
# artefacts written by this pipeline, never files from an untrusted source.
with open(IMPUTATION_META, 'rb') as f:
    meta = pickle.load(f)

features = meta["features"]
treatment_var = meta.get("treatment_var", "TREATED")
all_outcomes = meta.get("outcomes", [outcome])

# Load fitted libraries and top estimator names from 03a/04a cache
FITTED_LIBS_PATH = ANALYSIS_DIR / f"{outcome}_fitted_libraries.pkl"
TOP_NAMES_PATH = ANALYSIS_DIR / f"{outcome}_top_estimator_names.pkl"

# Fail with an actionable message (mirrors the metadata check above) instead of
# a bare FileNotFoundError from joblib.
if not FITTED_LIBS_PATH.exists():
    raise FileNotFoundError(
        f"Missing fitted libraries at {FITTED_LIBS_PATH}. Run the 03a/04a notebooks first."
    )
fitted_libraries = joblib.load(FITTED_LIBS_PATH)

# If top names not present, fall back to all estimator names.
# Prefer the "pert_none" library, but tolerate its absence by using the first
# library available (same rule the model-based inference cell applies).
if TOP_NAMES_PATH.exists():
    top_estimator_names = joblib.load(TOP_NAMES_PATH)
else:
    fallback_key = "pert_none" if "pert_none" in fitted_libraries else next(iter(fitted_libraries))
    top_estimator_names = list(fitted_libraries[fallback_key].keys())

# Build 12-fold ensemble (3 CV splits × 4 folds)
top_ensemble = make_top_ensemble(fitted_libraries, top_estimator_names)
print(f"Built ensemble with {top_ensemble.n_splits} folds")


Built ensemble with 12 folds


In [42]:
# Build the [0.9, 1] subgroup (or [0, 0.1] when dir_neg) by selection frequency across the CV folds
dir_neg = False
q_bot, q_top = (0.9, 1.0) if not dir_neg else (0.0, 0.1)

n_folds = top_ensemble.n_splits
n_samples = len(top_ensemble.y)
indicators = np.zeros((n_samples, n_folds), dtype=bool)

# One boolean column per fold: is the unit inside the [q_bot, q_top] quantile band?
for fold in range(n_folds):
    ind = top_ensemble.results[fold].get_subgroup_indicator(q_bot, q_top, kind="all")
    indicators[:, fold] = ind

# Keep units selected in at least half of the folds
freq = indicators.mean(axis=1)
subgroup_indicator = freq >= 0.5

# Load aligned trainval data to get y and t
trainval_df = pd.read_csv(ANALYSIS_DIR / "trainval_data.csv")
# The boolean mask is indexed positionally, so the CSV must have one row per ensemble sample
assert len(trainval_df) == n_samples, (
    f"trainval_data.csv has {len(trainval_df)} rows but the ensemble has {n_samples} samples"
)
y = trainval_df[outcome].values
if 'TREATED' in trainval_df.columns:
    t = trainval_df['TREATED'].values.astype(int)
elif treatment_var in trainval_df.columns:
    t = trainval_df[treatment_var].values.astype(int)
else:
    raise KeyError("trainval data must include 'TREATED' or treatment_var column")

ATE = get_Neyman_ATE(y[subgroup_indicator], t[subgroup_indicator])
# t-stat vs zero to align sign with ATE; only defined when both arms are present in the subgroup
if (t[subgroup_indicator].sum() > 0) and ((1 - t[subgroup_indicator]).sum() > 0):
    # presumably the standard error of the subgroup ATE — confirm against methods.causal_functions
    CATE_std = get_subgroup_CATE_std(y, t, subgroup_indicator)
    t_stat = ATE / CATE_std if (np.isfinite(CATE_std) and CATE_std > 0) else np.nan
    # Two-sided normal p-value; sf() keeps precision in the far tail where 1 - cdf() rounds to 0
    p_value = 2 * norm.sf(abs(t_stat)) if np.isfinite(t_stat) else np.nan
else:
    t_stat, p_value = np.nan, np.nan

# Report non-finite statistics as None. np.isfinite is used rather than `is np.nan`
# because a NaN produced by arithmetic is not the same object as np.nan.
print({
    "n_samples": int(n_samples),
    "subgroup_size": int(subgroup_indicator.sum()),
    "ATE": float(ATE) if np.isfinite(ATE) else None,
    "t_stat": float(t_stat) if np.isfinite(t_stat) else None,
    "p_value": float(p_value) if np.isfinite(p_value) else None,
})


{'n_samples': 28816, 'subgroup_size': 2173, 'ATE': 0.03796465879799213, 't_stat': 2.0065226623383516, 'p_value': 0.044800513346757054}


In [41]:
# Holdout subgroup via fold models
holdout_df = pd.read_csv(ANALYSIS_DIR / "holdout_data.csv")
# Non-numeric or missing feature values are coerced to 0.0 before prediction
X_hold = holdout_df[features].copy().apply(pd.to_numeric, errors='coerce').fillna(0.0).values
y_hold = holdout_df[outcome].values
if 'TREATED' in holdout_df.columns:
    t_hold = holdout_df['TREATED'].values.astype(int)
elif treatment_var in holdout_df.columns:
    t_hold = holdout_df[treatment_var].values.astype(int)
else:
    raise KeyError("Holdout data must include 'TREATED' or treatment_var column")

library_names = [name for name in ["pert_none", "pert_cv_0", "pert_cv_1"] if name in fitted_libraries]
fold_indicators = []
for lib_name in library_names:
    lib = fitted_libraries[lib_name]
    # Top estimators actually present in this library; skip the library if there are none
    members = [lib[est_name] for est_name in top_estimator_names if est_name in lib]
    if not members:
        continue
    n_folds_lib = len(next(iter(lib.values())).results)
    for fold in range(n_folds_lib):
        # Threshold = quantile of the ensemble-averaged training-fold CATE predictions
        tau_train_ens = np.mean(np.vstack([est.results[fold].tau_train for est in members]), axis=0)
        thr = np.quantile(tau_train_ens, q_bot if not dir_neg else q_top)
        # Apply the same fold models to the holdout units and compare with that threshold
        tau_hold_ens = np.mean(np.vstack([est.predict_on_fold(fold, X_hold) for est in members]), axis=0)
        ind = (tau_hold_ens >= thr) if not dir_neg else (tau_hold_ens <= thr)
        fold_indicators.append(ind.astype(bool))

if not fold_indicators:
    raise ValueError("No fold models found for the top estimators; cannot build the holdout subgroup")

# Keep holdout units selected by at least half of the fold models
fold_indicators = np.vstack(fold_indicators)
freq_hold = fold_indicators.mean(axis=0)
subgroup_hold = freq_hold >= 0.5

ATE_hold = get_Neyman_ATE(y_hold[subgroup_hold], t_hold[subgroup_hold]) if subgroup_hold.any() else np.nan
# t-stat vs zero to align sign with ATE; only defined when both arms are present in the subgroup
if subgroup_hold.any() and (t_hold[subgroup_hold].sum() > 0) and ((1 - t_hold[subgroup_hold]).sum() > 0):
    CATE_std_hold = get_subgroup_CATE_std(y_hold, t_hold, subgroup_hold)
    t_stat_hold = ATE_hold / CATE_std_hold if (np.isfinite(CATE_std_hold) and CATE_std_hold > 0) else np.nan
    # Two-sided normal p-value; sf() keeps precision in the far tail where 1 - cdf() rounds to 0
    p_value_hold = 2 * norm.sf(abs(t_stat_hold)) if np.isfinite(t_stat_hold) else np.nan
else:
    t_stat_hold, p_value_hold = np.nan, np.nan

# Report non-finite statistics as None. np.isfinite is used rather than `is np.nan`
# because a NaN produced by arithmetic is not the same object as np.nan.
print({
    "n_holdout": int(len(holdout_df)),
    "subgroup_size_holdout": int(subgroup_hold.sum()),
    "ATE_holdout": float(ATE_hold) if np.isfinite(ATE_hold) else None,
    "t_stat_holdout": float(t_stat_hold) if np.isfinite(t_stat_hold) else None,
    "p_value_holdout": float(p_value_hold) if np.isfinite(p_value_hold) else None,
})


{'n_holdout': 7204, 'subgroup_size_holdout': 544, 'ATE_holdout': 0.008155523944997634, 't_stat_holdout': 0.2222414112465909, 'p_value_holdout': 0.8241259585586229}


In [43]:
# Pooled stats (trainval + holdout)
y_all = np.asarray(trainval_df[outcome].values, dtype=float)
t_all = np.asarray(
    trainval_df['TREATED'].values.astype(int) if 'TREATED' in trainval_df.columns
    else trainval_df[treatment_var].values.astype(int),
    dtype=int,
)

# Stack trainval first, holdout second, in the same order for mask, outcome and treatment
subgroup_all = np.concatenate([
    np.asarray(subgroup_indicator, dtype=bool),
    np.asarray(subgroup_hold, dtype=bool),
])
y_all_pool = np.concatenate([y_all, np.asarray(y_hold, dtype=float)])
t_all_pool = np.concatenate([t_all, np.asarray(t_hold, dtype=int)])
# The mask is applied positionally, so all three pooled arrays must line up
assert subgroup_all.shape == y_all_pool.shape == t_all_pool.shape, (
    "pooled subgroup mask, outcomes and treatments must have the same length"
)

ATE_all = get_Neyman_ATE(y_all_pool[subgroup_all], t_all_pool[subgroup_all]) if subgroup_all.any() else np.nan
# t-stat vs zero to align sign with ATE; only defined when both arms are present in the subgroup
if subgroup_all.any() and (t_all_pool[subgroup_all].sum() > 0) and ((1 - t_all_pool[subgroup_all]).sum() > 0):
    CATE_std_all = get_subgroup_CATE_std(y_all_pool, t_all_pool, subgroup_all)
    t_stat_all = ATE_all / CATE_std_all if (np.isfinite(CATE_std_all) and CATE_std_all > 0) else np.nan
    # Two-sided normal p-value; sf() keeps precision in the far tail where 1 - cdf() rounds to 0
    p_value_all = 2 * norm.sf(abs(t_stat_all)) if np.isfinite(t_stat_all) else np.nan
else:
    t_stat_all, p_value_all = np.nan, np.nan

# Non-finite statistics are reported as None
print({
    "n_all": int(len(y_all_pool)),
    "subgroup_size_all": int(np.sum(subgroup_all)),
    "ATE_all": float(ATE_all) if np.isfinite(ATE_all) else None,
    "t_stat_all": float(t_stat_all) if np.isfinite(t_stat_all) else None,
    "p_value_all": float(p_value_all) if np.isfinite(p_value_all) else None,
})


{'n_all': 36020, 'subgroup_size_all': 2717, 'ATE_all': 0.032639929933412914, 't_stat_all': 1.9404589115295334, 'p_value_all': 0.052323943028677045}


In [39]:
# Model-based inference (ATE + 95% CI) per estimator using CausalML estimate_ate on fixed subgroups
# Uses pre-trained fold models saved in fitted_libraries; if unavailable, skips that fold

from scipy.stats import norm
import importlib

# Library with fold models
lib_key = "pert_none" if "pert_none" in fitted_libraries else list(fitted_libraries.keys())[0]
lib = fitted_libraries[lib_key]

# The subgroup masks are built by the earlier subgroup cells and cannot be
# rebuilt from disk here, so fail with an actionable message if they are missing.
for required_mask in ("subgroup_indicator", "subgroup_hold"):
    if required_mask not in globals():
        raise NameError(
            f"'{required_mask}' is not defined; run the trainval and holdout subgroup cells before this one."
        )

# Ensure subgroup arrays (as NumPy arrays)
trainval_df = pd.read_csv(ANALYSIS_DIR / "trainval_data.csv")
X_tv = np.asarray(trainval_df[features].copy().apply(pd.to_numeric, errors='coerce').fillna(0.0).values, dtype=float)
y_tv = np.asarray(trainval_df[outcome].values, dtype=float)
t_tv = np.asarray(
    trainval_df['TREATED'].values.astype(int) if 'TREATED' in trainval_df.columns else trainval_df[treatment_var].values.astype(int),
    dtype=int,
)
mask_tv = np.asarray(subgroup_indicator, dtype=bool)
X_sub = X_tv[mask_tv]
y_sub = y_tv[mask_tv]
t_sub = t_tv[mask_tv]

# Holdout arrays: reuse those from the earlier cell, otherwise reload all three
# together so X, y and t always come from the same file.
if any(name not in globals() for name in ("X_hold", "y_hold", "t_hold")):
    holdout_df = pd.read_csv(ANALYSIS_DIR / "holdout_data.csv")
    X_hold = holdout_df[features].copy().apply(pd.to_numeric, errors='coerce').fillna(0.0).values
    y_hold = holdout_df[outcome].values
    t_hold = (
        holdout_df['TREATED'].values if 'TREATED' in holdout_df.columns else holdout_df[treatment_var].values
    )
# Force NumPy dtypes regardless of where the arrays came from
X_hold = np.asarray(X_hold, dtype=float)
y_hold = np.asarray(y_hold, dtype=float)
t_hold = np.asarray(t_hold, dtype=int)
mask_hold = np.asarray(subgroup_hold, dtype=bool)
# Boolean indexing with an all-False mask already yields correctly typed empty arrays
X_hold_sub = X_hold[mask_hold]
y_hold_sub = y_hold[mask_hold]
t_hold_sub = t_hold[mask_hold]

# Propensity helper: CausalML is looked up dynamically to avoid linter complaints
def compute_propensity(X, t):
    """Estimate treatment propensities for the rows of ``X``.

    The primary path fits CausalML's ``ElasticNetPropensityModel`` via
    ``fit_predict`` and returns its output flattened to a 1-D float array.
    If anything on that path raises, a scikit-learn ``LogisticRegression``
    (``max_iter=1000``) is fitted on the same rows and the second column of
    its ``predict_proba`` is returned instead.

    Args:
        X: feature matrix; converted to a float array.
        t: treatment assignments; converted to an int array.

    Returns:
        Array of propensity scores, one per row of ``X``.
    """
    try:
        propensity_module = importlib.import_module("causalml.propensity")
        model_cls = propensity_module.ElasticNetPropensityModel
        scores = model_cls().fit_predict(
            np.asarray(X, dtype=float),
            np.asarray(t, dtype=int),
        )
        return np.asarray(scores, dtype=float).reshape(-1)
    except Exception:
        # Best-effort fallback: plain logistic regression fitted and scored in-sample
        from sklearn.linear_model import LogisticRegression
        X_arr = np.asarray(X, dtype=float)
        t_arr = np.asarray(t, dtype=int)
        fallback = LogisticRegression(max_iter=1000)
        fallback.fit(X_arr, t_arr)
        return fallback.predict_proba(X_arr)[:, 1]

alpha = 0.05
zcrit = norm.ppf(1 - alpha/2)


def ivw(ates, ses):
    """Inverse-variance-weighted combination of per-fold ATE estimates.

    ``ates`` and ``ses`` are paired element-wise. A fold is used only when both
    its ATE and its SE are finite and the SE is positive, so one bad fold no
    longer misaligns the pairs or discards every other fold.

    Args:
        ates: per-fold ATE point estimates.
        ses: per-fold standard errors, same length and order as ``ates``.

    Returns:
        ``(ate_hat, lb_hat, ub_hat)``: the weighted estimate and its
        ``alpha``-level normal confidence bounds, or three NaNs when the inputs
        differ in length or no usable pair remains.
    """
    # NOTE(review): the fold models are evaluated on the same units, so their
    # estimates are presumably correlated and this SE may be too small — confirm.
    ates = np.asarray(ates, dtype=float).reshape(-1)
    ses = np.asarray(ses, dtype=float).reshape(-1)
    if ates.size != ses.size:
        return np.nan, np.nan, np.nan
    keep = np.isfinite(ates) & np.isfinite(ses) & (ses > 0)
    if not keep.any():
        return np.nan, np.nan, np.nan
    w = 1.0 / (ses[keep] ** 2)
    ate_hat = float(np.sum(w * ates[keep]) / np.sum(w))
    se_hat = float(np.sqrt(1.0 / np.sum(w)))
    return ate_hat, ate_hat - zcrit * se_hat, ate_hat + zcrit * se_hat


def _fold_ate_se(ml, X, t, y, p, treat_group):
    """Call ``ml.estimate_ate`` with ``pretrain=True`` and return ``(ate, se)``.

    The SE is recovered from the width of the returned interval, assuming it is
    a symmetric normal interval at level ``alpha`` (TODO confirm against the
    meta-learner); it is NaN when either bound is non-finite.
    """
    ate_arr, lb_arr, ub_arr = ml.estimate_ate(X, t, y, p={treat_group: p}, pretrain=True)
    ate = float(np.asarray(ate_arr).reshape(-1)[0])
    lb = float(np.asarray(lb_arr).reshape(-1)[0])
    ub = float(np.asarray(ub_arr).reshape(-1)[0])
    se = (ub - lb) / (2 * zcrit) if np.isfinite(ub) and np.isfinite(lb) else np.nan
    return ate, se


# Propensities depend only on the data, not on the estimator or fold, so fit
# them once instead of once per estimator (and once per fold for the pooled set).
p_tv = compute_propensity(X_sub, t_sub)
has_holdout = X_hold_sub.shape[0] > 0
p_hold = compute_propensity(X_hold_sub, t_hold_sub) if has_holdout else np.array([])

# Pooled subgroup = trainval subgroup followed by holdout subgroup (trainval only if holdout is empty)
try:
    if has_holdout:
        X_all = np.vstack([np.asarray(X_sub, dtype=float), np.asarray(X_hold_sub, dtype=float)])
        t_all = np.concatenate([np.asarray(t_sub, dtype=int), np.asarray(t_hold_sub, dtype=int)])
        y_all = np.concatenate([np.asarray(y_sub, dtype=float), np.asarray(y_hold_sub, dtype=float)])
        p_all = compute_propensity(X_all, t_all)
    else:
        X_all = np.asarray(X_sub, dtype=float)
        t_all = np.asarray(t_sub, dtype=int)
        y_all = np.asarray(y_sub, dtype=float)
        p_all = p_tv
except Exception as exc:
    # Stay best-effort: pooled columns are reported as missing, but say why
    warnings.warn(f"Pooled subgroup could not be prepared; pooled estimates are skipped: {exc!r}")
    p_all = None

# (label, X, t, y, p) for every sample that can be evaluated
eval_samples = [("trainval", X_sub, t_sub, y_sub, p_tv)]
if has_holdout:
    eval_samples.append(("holdout", X_hold_sub, t_hold_sub, y_hold_sub, p_hold))
if p_all is not None:
    eval_samples.append(("all", X_all, t_all, y_all, p_all))
sample_labels = ("trainval", "holdout", "all")

rows = []
for est_name in top_estimator_names:
    if est_name not in lib:
        continue
    est = lib[est_name]
    n_folds = len(est.results)

    # Determine treated group id (binary 0/1 default)
    ml0 = est.results[0].meta_learner
    try:
        control = getattr(ml0, 'control_name', 0)
        groups = np.array(getattr(ml0, 't_groups', np.array([1])))
        treat_groups = groups[groups != control]
        treat_group = int(treat_groups[0]) if treat_groups.size > 0 else 1
    except Exception:
        treat_group = 1

    # Collect per-fold ATE and SE (from CI widths) for each sample
    fold_ates = {label: [] for label in sample_labels}
    fold_ses = {label: [] for label in sample_labels}
    for f in range(n_folds):
        ml = est.results[f].meta_learner
        for label, X_s, t_s, y_s, p_s in eval_samples:
            try:
                ate_f, se_f = _fold_ate_se(ml, X_s, t_s, y_s, p_s, treat_group)
            except Exception as exc:
                # Skip the fold as before, but leave a trace instead of failing silently
                warnings.warn(f"{est_name} fold {f} ({label}): estimate_ate failed, fold skipped: {exc!r}")
                continue
            fold_ates[label].append(ate_f)
            fold_ses[label].append(se_f)

    # Inverse-variance combine across folds; NaN results are stored as None
    row = {"estimator": est_name}
    for label in sample_labels:
        ate_hat, lb_hat, ub_hat = ivw(fold_ates[label], fold_ses[label])
        row[f"ATE_{label}"] = None if np.isnan(ate_hat) else ate_hat
        row[f"CI_{label}_lb"] = None if np.isnan(lb_hat) else lb_hat
        row[f"CI_{label}_ub"] = None if np.isnan(ub_hat) else ub_hat
    rows.append(row)

# Explicit columns keep the frame well-formed even when no estimator produced a row
result_columns = ["estimator"] + [
    col
    for label in sample_labels
    for col in (f"ATE_{label}", f"CI_{label}_lb", f"CI_{label}_ub")
]
est_stats_infer = (
    pd.DataFrame(rows, columns=result_columns)
    .set_index("estimator")
    .sort_values(by="ATE_trainval", ascending=False)
)

# Helpers for the ensemble-average row (mean of ATEs and SEs across estimators)
from scipy.stats import norm
alpha = 0.05
zcrit = norm.ppf(1 - alpha/2)

def _se_from_ci(lb, ub):
    """Half-width of the interval ``[lb, ub]`` divided by ``zcrit`` (element-wise)."""
    lower = np.asarray(lb, dtype=float)
    upper = np.asarray(ub, dtype=float)
    return (upper - lower) / (2.0 * zcrit)

def _mean_ate_se(df, ate_col, lb_col, ub_col):
    """Average the ATE column and the CI-implied SEs over the rows of ``df``.

    Only rows with a finite ATE and a finite, positive SE contribute.

    Returns:
        ``(mean_ate, mean_se)`` as floats, or ``(nan, nan)`` when a column is
        missing from ``df`` or no row qualifies.
    """
    needed = (ate_col, lb_col, ub_col)
    if any(col not in df.columns for col in needed):
        return np.nan, np.nan
    ate_vals, lower, upper = (
        np.asarray(pd.to_numeric(df[col], errors="coerce"), dtype=float)
        for col in needed
    )
    se_vals = _se_from_ci(lower, upper)
    usable = np.isfinite(ate_vals) & np.isfinite(se_vals) & (se_vals > 0)
    if usable.any():
        return float(np.mean(ate_vals[usable])), float(np.mean(se_vals[usable]))
    return np.nan, np.nan

# Column triples (ATE, CI lower, CI upper) for each evaluation sample
trainval_cols = ("ATE_trainval", "CI_trainval_lb", "CI_trainval_ub")
holdout_cols = ("ATE_holdout", "CI_holdout_lb", "CI_holdout_ub")
pooled_cols = ("ATE_all", "CI_all_lb", "CI_all_ub")

# Cross-estimator mean ATE and mean SE per sample
ate_tv_m, se_tv_m = _mean_ate_se(est_stats_infer, *trainval_cols)
ate_ho_m, se_ho_m = _mean_ate_se(est_stats_infer, *holdout_cols)
ate_all_m, se_all_m = _mean_ate_se(est_stats_infer, *pooled_cols)

# Rebuild a normal CI around each mean ATE from the mean SE; missing pieces become None
row = {}
for (ate_key, lb_key, ub_key), (avg_ate, avg_se) in (
    (trainval_cols, (ate_tv_m, se_tv_m)),
    (holdout_cols, (ate_ho_m, se_ho_m)),
    (pooled_cols, (ate_all_m, se_all_m)),
):
    ate_missing = np.isnan(avg_ate)
    ci_missing = ate_missing or np.isnan(avg_se)
    row[ate_key] = None if ate_missing else avg_ate
    row[lb_key] = None if ci_missing else avg_ate - zcrit * avg_se
    row[ub_key] = None if ci_missing else avg_ate + zcrit * avg_se

# Append the summary row in the frame's own column order, then display the table
ensemble_values = [row.get(col, np.nan) for col in est_stats_infer.columns]
est_stats_infer.loc["ENSEMBLE_AVG"] = ensemble_values
est_stats_infer




Unnamed: 0_level_0,ATE_trainval,CI_trainval_lb,CI_trainval_ub,ATE_holdout,CI_holdout_lb,CI_holdout_ub,ATE_all,CI_all_lb,CI_all_ub
estimator,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
t_lasso,0.041219,0.023493,0.058945,0.039589,0.003913,0.075265,0.040893,0.02502,0.056767
t_logistic,0.039431,0.021744,0.057117,0.037755,0.002182,0.073328,0.039095,0.023259,0.054931
x_lasso,0.038718,0.019978,0.057458,0.037267,-0.000178,0.074712,0.038417,0.021667,0.055167
ENSEMBLE_AVG,0.03979,0.021739,0.05784,0.038204,0.001972,0.074435,0.039468,0.023315,0.055622
