# OHIE Data Preparation – Stage 1

This notebook performs the *data preparation stage* for my PhD writing sample using the Oregon Health Insurance Experiment (OHIE).  
The goal here is **only** to clean and structure the data, not to run the causal ML or main econometric models.

Concretely, this notebook:

1. Loads the four OHIE source files (baseline 0m survey, 12m survey, descriptive/lottery list, and state program records).
2. Harmonizes identifiers and data types and merges to a person-level file.
3. Defines the instrument `Z` (lottery win) and the treatment `W` (ever enrolled in Medicaid/OHP Standard by 30 September 2009), plus the main outcome and covariate blocks.
4. Constructs the analysis sample: individuals who responded to both the baseline (0m) and the 12-month (12m) surveys.
5. Prefixes variables as `Y_` (outcomes) and `X_` (baseline covariates) so that the later causal ML code can treat them systematically.
6. Performs careful missing-data diagnostics and imputes covariates using an iterative tree-based imputer with missingness flags.
7. Constructs a numeric income measure and a catastrophic expenditure indicator.
8. Saves a clean intermediate dataset to disk for use in all downstream empirical work.


## 1. Setup

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path

from statsmodels.stats.proportion import proportions_ztest
import statsmodels.api as sm

from sklearn.experimental import enable_iterative_imputer  # noqa: F401
from sklearn.impute import IterativeImputer
from sklearn.ensemble import ExtraTreesRegressor

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
# Silence UserWarnings to keep notebook output readable.
# NOTE(review): sklearn's ConvergenceWarning subclasses UserWarning, so this also
# hides IterativeImputer's "early stopping criterion not reached" message —
# consider narrowing the filter if convergence matters.
warnings.filterwarnings("ignore", category=UserWarning)

# Display floats with thousands separators and three decimals.
pd.set_option("display.float_format", "{:,.3f}".format)

# Seed NumPy's global RNG; the imputer below also receives RANDOM_SEED explicitly.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Directory holding the raw OHIE .dta files (relative to the working directory).
DATA_DIR = Path("../Data_Used")
# Raise explicitly: an `assert` would be skipped entirely under `python -O`.
if not DATA_DIR.exists():
    raise FileNotFoundError(f"DATA_DIR does not exist: {DATA_DIR.resolve()}")
DATA_DIR


## 2. Helper functions

In [None]:
def assert_unique_key(df: pd.DataFrame, key):
    """Check that the column(s) in `key` identify every row of `df` uniquely.

    `key` is either a single column name or a list of column names.
    Returns True when the key is unique; otherwise raises ValueError stating
    how many rows repeat an earlier key value.
    """
    key_cols = [key] if isinstance(key, str) else key
    n_repeats = df.duplicated(subset=key_cols).sum()
    if n_repeats <= 0:
        return True
    raise ValueError(f"Key {key_cols} is not unique: {n_repeats} duplicate rows found.")


def summarize_missingness(df: pd.DataFrame, cols, name: str, max_rows: int = 30):
    """Show the share of NA values per column, most-missing first.

    Names in `cols` that are not columns of `df` are skipped; at most
    `max_rows` columns are shown. Uses `display`, which is not imported in
    this file — assumes an IPython/Jupyter session.
    """
    present = [col for col in cols if col in df.columns]
    na_share = df[present].isna().mean()
    na_share = na_share.sort_values(ascending=False)
    print(f"\nMissingness summary for {name} (share of rows that are NA):")
    display(na_share.head(max_rows))


def quick_tabulate(df: pd.DataFrame, col: str):
    """Show the count (`n`) and share (`pct`) of each value of `col`, NA included.

    Uses `display`, which is not imported in this file — assumes an
    IPython/Jupyter session.
    """
    print(f"\nValue counts for {col}:")
    counts = df[col].value_counts(dropna=False).to_frame("n")
    counts["pct"] = counts["n"] / counts["n"].sum()
    display(counts)


## 3. Load 12‑month survey (outcomes)

In [None]:
# 12m survey – outcome block. Only these columns are read from the Stata file.
col_keep_y = [
    # Identifier, response indicator and survey weight
    "person_id",          # merge key shared by all source files
    "returned_12m",       # 12m survey response indicator
    "weight_12m",         # official OHIE 12m nonresponse weight

    # Medical debt
    "cost_any_owe_12m",   # [0/1] any medical debt
    "cost_tot_owe_12m",   # [Num] total medical debt

    # Financial strain
    "cost_borrow_12m",    # [0/1] borrowed / skipped bills to pay medical expenses
    "cost_refused_12m",   # [0/1] refused care due to non-payment

    # Out-of-pocket (OOP) spending and income
    "cost_tot_oop_12m",   # [Num] total OOP spending
    "cost_any_oop_12m",   # [0/1] any OOP spending
    "hhinc_cat_12m",      # [Cat] household income category (feeds the catastrophic indicator)

    # OOP split by type of service
    "cost_doc_oop_12m",   # [Num] doctor visits
    "cost_er_oop_12m",    # [Num] ER visits
    "cost_rx_oop_12m",    # [Num] prescriptions
    "cost_oth_oop_12m",   # [Num] other care
]

y_path = DATA_DIR / "oregonhie_survey12m_vars.dta"
df_y = pd.read_stata(
    y_path,
    columns=col_keep_y,
    preserve_dtypes=True,
    convert_categoricals=False,
)

# A row without person_id cannot be merged, so reject the file outright.
if df_y["person_id"].isna().sum() > 0:
    raise ValueError("df_y has missing person_id values.")

df_y = df_y.astype({"person_id": "int64"})
assert_unique_key(df_y, "person_id")

print("df_y shape:", df_y.shape)
df_y.head()


## 4. Load baseline 0m survey (covariates)

In [None]:
# 0m survey – baseline covariate block
col_keep_x = [
    # --- A. Key & filter ---
    "person_id",        # Merge key
    "returned_0m",      # Baseline survey response
    "surv_lang_0m",     # Survey language

    # --- B. Baseline health need & utilization ---
    "needmet_med_0m",   # [0/1] Got all needed medical care (unmet need)
    "needmet_rx_0m",    # [0/1] Got all needed medications
    "need_rx_0m",       # [0/1] Needed medication
    "need_med_0m",      # [0/1] Needed medical care

    "rx_num_mod_0m",    # [Num] # of distinct prescriptions
    "doc_num_mod_0m",   # [Num] # doctor visits last 6m
    "er_num_mod_0m",    # [Num] # ER visits last 6m
    "hosp_num_mod_0m",  # [Num] # hospitalizations last 6m

    "ins_months_0m",    # [0–6] # months insured in last 6m

    # --- C. Baseline health status ---
    "health_gen_0m",    # [1–5] Self‑rated health
    "baddays_phys_0m",  # [0–30] Days physical health "not good"
    "baddays_ment_0m",  # [0–30] Days mental health "not good"
    "health_chg_0m",    # [1–3] Health better/same/worse vs 1 year ago

    # --- D. Diagnosed conditions ---
    "dia_dx_0m",        # [0/1] Diabetes
    "ast_dx_0m",        # [0/1] Asthma
    "hbp_dx_0m",        # [0/1] Hypertension
    "emp_dx_0m",        # [0/1] Emphysema/COPD
    "chf_dx_0m",        # [0/1] Congestive heart failure
    "dep_dx_0m",        # [0/1] Depression / anxiety

    # --- E. Demographics ---
    "female_0m",        # [0/1] Female
    "birthyear_0m",     # [Year] Birth year (used for baseline age)
    "edu_0m",           # [Cat] Education

    # --- F. Race & ethnicity ---
    "race_hisp_0m",     # [0/1] Hispanic
    "race_white_0m",    # [0/1] White
    "race_black_0m",    # [0/1] Black
    "race_amerindian_0m",  # [0/1] American Indian / Alaska Native
    "race_asian_0m",       # [0/1] Asian
    "race_pacific_0m",     # [0/1] Native Hawaiian / Pacific Islander
    "race_other_qn_0m",    # [0/1] Other race

    # --- G. Employment & household ---
    "employ_0m",        # [0/1] Employed
    "employ_hrs_0m",    # [Cat] Hours worked / week
    "hhinc_cat_0m",     # [Cat] Baseline HH income category
    "hhsize_0m",        # [Num] Household size
    "num19_0m",         # [Num] # children under 19

    # --- H. Baseline financial variables ---
    "cost_any_oop_0m",  # [0/1] Any OOP at baseline
    "cost_borrow_0m",   # [0/1] Borrowed / skipped bills for medical care
    "cost_any_owe_0m",  # [0/1] Owe any medical debt
    "cost_tot_owe_0m",  # [Num] Total baseline medical debt
    "cost_refused_0m",  # [0/1] Refused care for non‑payment
    # Baseline OOP spending (cost_tot_oop_0m, or its components) is requested
    # separately below, and only if the file actually contains it.

    # --- I. Location ---
    # `zip_msa_list` is deliberately NOT requested here: it is a lottery-list
    # variable loaded from the descriptive file. Requesting it from both files
    # would make the merge suffix both copies (zip_msa_list_x / zip_msa_list_y),
    # and the "zip_msa_list" covariate would then silently drop out of X.
]

x_path = DATA_DIR / "oregonhie_survey0m_vars.dta"

# Baseline OOP: prefer the survey total; otherwise sum any service-level
# components. Inspect the file header first, because pd.read_stata raises when
# asked for a column the file does not contain.
oop_total_col_0m = "cost_tot_oop_0m"
oop_component_cols_0m = [
    "cost_doc_oop_0m", "cost_er_oop_0m", "cost_rx_oop_0m", "cost_oth_oop_0m",
]
with pd.read_stata(x_path, iterator=True) as stata_reader:
    vars_in_x_file = set(stata_reader.variable_labels())

if oop_total_col_0m in vars_in_x_file:
    oop_cols_0m = [oop_total_col_0m]
else:
    oop_cols_0m = [c for c in oop_component_cols_0m if c in vars_in_x_file]
    print(f"{oop_total_col_0m} not found; OOP components available: {oop_cols_0m}")

df_x = pd.read_stata(
    x_path,
    columns=col_keep_x + oop_cols_0m,
    convert_categoricals=False,
    preserve_dtypes=True,
)

if df_x["person_id"].isna().any():
    raise ValueError("df_x has missing person_id values.")

df_x["person_id"] = df_x["person_id"].astype("int64")
assert_unique_key(df_x, "person_id")

# Construct a cleaned total OOP measure at baseline.
if oop_total_col_0m in df_x.columns:
    df_x["cost_tot_oop_correct_0m"] = df_x[oop_total_col_0m]
elif oop_cols_0m:
    # min_count=1: rows where every component is missing stay NaN instead of
    # being summed to 0, so their missingness is not hidden.
    df_x["cost_tot_oop_correct_0m"] = df_x[oop_cols_0m].sum(axis=1, min_count=1)
else:
    df_x["cost_tot_oop_correct_0m"] = np.nan

print("df_x shape:", df_x.shape)
df_x.head()


## 5. Load descriptive / lottery list file (instrument)

In [None]:
# Lottery list (descriptive file): instrument, household identifier and
# list-based covariates.
col_keep_iv = [
    "person_id",
    "household_id",
    "treatment",      # [0/1] lottery win -> instrument Z
    "numhh_list",     # [Num] number of people on the lottery list in the household
    "zip_msa_list",   # [0/1] MSA indicator (urban vs rural)
    "female_list",    # [0/1] gender at sign-up (backup)
    "birthyear_list", # [Year] birth year at sign-up (backup)
]

iv_path = DATA_DIR / "oregonhie_descriptive_vars.dta"
df_iv = pd.read_stata(
    iv_path,
    preserve_dtypes=True,
    convert_categoricals=False,
    columns=col_keep_iv,
)

# Both identifiers must be complete before they can be cast to integers.
if df_iv["person_id"].isna().sum() > 0:
    raise ValueError("df_iv has missing person_id values.")
if df_iv["household_id"].isna().sum() > 0:
    raise ValueError("df_iv has missing household_id values.")

df_iv = df_iv.astype({"person_id": "int64", "household_id": "int64"})
assert_unique_key(df_iv, "person_id")

print("df_iv shape:", df_iv.shape)
df_iv.head()


## 6. Load state program file (treatment)

In [None]:
# State program records: Medicaid (OHP) enrollment, the treatment W.
col_keep_w = [
    "person_id",
    "ohp_all_ever_firstn_30sep2009",  # main treatment: ever enrolled in Medicaid (OHP Standard) by 30 Sept 2009
    "ohp_all_mo_firstn_30sep2009",    # enrollment intensity: months enrolled
]

w_path = DATA_DIR / "oregonhie_stateprograms_vars.dta"
df_w = pd.read_stata(
    w_path,
    preserve_dtypes=True,
    convert_categoricals=False,
    columns=col_keep_w,
)

# Reject the file if any record lacks the merge key.
if df_w["person_id"].isna().sum() > 0:
    raise ValueError("df_w has missing person_id values.")

df_w = df_w.astype({"person_id": "int64"})
assert_unique_key(df_w, "person_id")

print("df_w shape:", df_w.shape)
df_w.head()


## 7. Merge all sources and define instrument (Z) and treatment (W)

In [None]:
# Merge onto the universe of lottery applicants (df_iv): every list member is
# kept, and survey / program fields are NaN where the person has no record.
df_merged = (
    df_iv
    .merge(df_x, on="person_id", how="left", validate="1:1")
    .merge(df_y, on="person_id", how="left", validate="1:1")
    .merge(df_w, on="person_id", how="left", validate="1:1")
)

print(f"After merge: {df_merged.shape[0]:,} rows × {df_merged.shape[1]} columns")

# Rename to follow econometric convention (Z = instrument, W = treatment)
df_merged = df_merged.rename(columns={
    "treatment": "Z_lottery",
    "ohp_all_ever_firstn_30sep2009": "W_medicaid",
    "ohp_all_mo_firstn_30sep2009": "W_medicaid_months",
})

# Rows with no enrollment value (including applicants absent from the state
# program file after the left merge) are treated as never enrolled.
df_merged["W_medicaid"] = df_merged["W_medicaid"].fillna(0).astype("int8")
df_merged["W_medicaid_months"] = df_merged["W_medicaid_months"].fillna(0)

quick_tabulate(df_merged, "Z_lottery")
quick_tabulate(df_merged, "W_medicaid")

# Z must be defined for everyone on the lottery list AND be binary. Check
# missingness explicitly rather than dropping NA before the binary check.
n_missing_z = int(df_merged["Z_lottery"].isna().sum())
if n_missing_z:
    raise ValueError(f"Z_lottery is missing for {n_missing_z} lottery-list rows.")

unique_z = sorted(df_merged["Z_lottery"].unique())
print("\nUnique values of Z_lottery:", unique_z)
if not set(unique_z).issubset({0, 1}):
    raise ValueError("Z_lottery is not binary")


## 8. Attrition analysis and use of 12m weights

In [None]:
print("\n=== ATTRITION ANALYSIS & WEIGHT DIAGNOSTICS ===")

from scipy.stats import ttest_ind

# --- 1. Response rates by instrument status ---------------------------------
# returned_12m is NaN for applicants with no 12m survey record (left merge);
# mean/count/sum skip NaN, so the table covers only rows with a 12m record.
attrition_table = df_merged.groupby("Z_lottery")["returned_12m"].agg(["mean", "count", "sum"])
attrition_table.columns = ["Response_Rate", "Total_N", "Responders_N"]
print("\n1. Response rates by lottery status:")
display(attrition_table)

# Two-sample z-test for equality of response rates across Z
z_stat, p_attrition = proportions_ztest(
    count=attrition_table["Responders_N"],
    nobs=attrition_table["Total_N"],
)
diff_pp = (attrition_table.loc[1, "Response_Rate"] - attrition_table.loc[0, "Response_Rate"]) * 100
print(f"\nDifference in response rates (Z=1 − Z=0): {diff_pp:+.2f} pp")
print(f"z = {z_stat:.3f}, p = {p_attrition:.4f}")

# --- 2. Normalized nonresponse weights --------------------------------------
# Respondents keep the official weight_12m rescaled to mean 1 among
# respondents; non-respondents get NaN.
df_merged["weight_attrition"] = np.where(
    df_merged["returned_12m"] == 1,
    df_merged["weight_12m"],
    np.nan,
)
mean_w = df_merged["weight_attrition"].mean(skipna=True)
df_merged["weight_attrition"] = df_merged["weight_attrition"] / mean_w

print("\n2. Distribution of attrition weights among 12m respondents:")
weight_stats = (
    df_merged.loc[df_merged["returned_12m"] == 1]
    .groupby("Z_lottery")["weight_attrition"]
    .agg(["mean", "min", "max"])
)
display(weight_stats)

resp = df_merged["returned_12m"] == 1
# Drop respondents with a missing weight so the one-sided t-test uses the same
# observations as the reported means (ttest_ind propagates NaN by default).
w_z1 = df_merged.loc[resp & (df_merged["Z_lottery"] == 1), "weight_attrition"].dropna()
w_z0 = df_merged.loc[resp & (df_merged["Z_lottery"] == 0), "weight_attrition"].dropna()
t_stat, p_weight = ttest_ind(w_z1, w_z0, alternative="greater")
print(f"\nTest of higher weights for Z=1 vs Z=0 among respondents: p = {p_weight:.4f}")
print(f"Z=1 mean weight: {w_z1.mean():.2f}, Z=0 mean weight: {w_z0.mean():.2f}")

# --- 3. Extreme-weight diagnostics ------------------------------------------
extreme_thresh = 5
extreme_pct = (
    df_merged.loc[resp, "weight_attrition"].gt(extreme_thresh).mean() * 100
)
print(f"\n3. Extreme weights (> {extreme_thresh}): {extreme_pct:.2f}% of 12m respondents")

# Optional: simple visualization (for interactive use)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
sns.histplot(
    data=df_merged.loc[resp],
    x="weight_attrition",
    hue="Z_lottery",
    element="step",
    common_norm=False,
    ax=ax1,
)
ax1.set_title("Weight distribution by lottery status (respondents)")

sns.boxplot(
    data=df_merged.loc[resp],
    x="Z_lottery",
    y="weight_attrition",
    ax=ax2,
)
# Tick labels assume both Z groups are present, in order Z=0, Z=1.
ax2.set_xticklabels(["Lost", "Won"])
ax2.set_title("Weight outliers")
plt.tight_layout()
plt.show()

attrition_diagnostic = {
    "response_rate_diff_pp": float(diff_pp),
    "attrition_p_value": float(p_attrition),
    "weight_mean_z1": float(w_z1.mean()),
    "weight_mean_z0": float(w_z0.mean()),
    "extreme_weight_pct": float(extreme_pct),
}
attrition_diagnostic


## 9. Define analysis sample and main X/Y blocks

In [None]:
# Keep individuals who answered both baseline (0m) and 12m surveys
df_analysis = df_merged.loc[
    (df_merged["returned_0m"] == 1) &
    (df_merged["returned_12m"] == 1)
].copy()

print(
    f"Analysis sample (responded to 0m and 12m): "
    f"{df_analysis.shape[0]:,} rows × {df_analysis.shape[1]} columns"
)

# Main outcome block (raw column names before prefixing)
Y_raw = [
    "cost_any_owe_12m",
    "cost_tot_owe_12m",
    "cost_borrow_12m",
    "cost_refused_12m",
    "cost_tot_oop_12m",
    "cost_any_oop_12m",
    "hhinc_cat_12m",
    "cost_doc_oop_12m",
    "cost_er_oop_12m",
    "cost_rx_oop_12m",
    "cost_oth_oop_12m",
]

# Main covariate block (baseline + lottery‑list covariates)
X_raw = [
    "surv_lang_0m",
    "needmet_med_0m",
    "needmet_rx_0m",
    "need_rx_0m",
    "need_med_0m",
    "rx_num_mod_0m",
    "doc_num_mod_0m",
    "er_num_mod_0m",
    "hosp_num_0m",
    "hosp_num_mod_0m",   # keep both in case only one is present
    "ins_months_0m",
    "health_gen_0m",
    "baddays_phys_0m",
    "baddays_ment_0m",
    "health_chg_0m",
    "dia_dx_0m",
    "ast_dx_0m",
    "hbp_dx_0m",
    "emp_dx_0m",
    "chf_dx_0m",
    "dep_dx_0m",
    "female_0m",
    "birthyear_0m",
    "edu_0m",
    "race_hisp_0m",
    "race_white_0m",
    "race_black_0m",
    "race_amerindian_0m",
    "race_asian_0m",
    "race_pacific_0m",
    "race_other_qn_0m",
    "employ_0m",
    "employ_hrs_0m",
    "hhinc_cat_0m",
    "hhsize_0m",
    "num19_0m",
    "cost_tot_owe_0m",
    "cost_borrow_0m",
    "cost_any_owe_0m",
    "cost_refused_0m",
    "cost_any_oop_0m",
    "cost_tot_oop_correct_0m",
    "zip_msa_list",
]

# Keep only the X/Y columns that are actually present, and report the ones
# that are not: a renamed or merge-suffixed column (e.g. "<name>_x"/"<name>_y")
# would otherwise vanish from the covariate/outcome blocks without notice.
missing_y = [c for c in Y_raw if c not in df_analysis.columns]
missing_x = [c for c in X_raw if c not in df_analysis.columns]
Y_raw = [c for c in Y_raw if c in df_analysis.columns]
X_raw = [c for c in X_raw if c in df_analysis.columns]

print("\nIncluded Y columns:", Y_raw)
print("\nIncluded X columns:", X_raw)
if missing_y:
    print("\nRequested Y columns not found (dropped):", missing_y)
if missing_x:
    print("\nRequested X columns not found (dropped):", missing_x)

summarize_missingness(df_analysis, Y_raw, "Y (outcomes)")
summarize_missingness(df_analysis, X_raw, "X (baseline covariates)")


## 10. Prefix X_ and Y_ variable names

In [None]:
# Tag every covariate with an X_ prefix and every outcome with a Y_ prefix so
# later code can select each block by name. Outcome names are merged last, so
# they win for any column that appears in both lists.
rename_map = {
    **{raw: f"X_{raw}" for raw in X_raw},
    **{raw: f"Y_{raw}" for raw in Y_raw},
}

df_analysis_sample = df_analysis.rename(columns=rename_map).copy()

X_cols = [f"X_{raw}" for raw in X_raw]
Y_cols = [f"Y_{raw}" for raw in Y_raw]

print(f"After renaming, analysis sample: {df_analysis_sample.shape[0]:,} rows")
print("\nFirst few X variables:", X_cols[:10])
print("First few Y variables:", Y_cols[:10])


## 11. Income mapping and catastrophic expenditure indicator

In [None]:
# OHIE 12m household-income categories -> numeric midpoints (2008 dollars).
#   * Category 1 is "$0".
#   * Categories 2–21 are $2,500-wide bins: category k covers
#     $2,500·(k−2)+1 to $2,500·(k−1) (2 = "$1 to $2,500", ...,
#     21 = "$47,501 to $50,000"). Each maps to 2,500·(k−2) + 1,250.
#   * Category 22 is the open-ended "$50,001 or more" bin, set to $60,000
#     (exact value not important for our purposes).
income_map_12m = {1: 0}
income_map_12m.update({cat: 2500 * (cat - 2) + 1250 for cat in range(2, 22)})
income_map_12m[22] = 60000

# Convert the 12m income category to a numeric amount via the midpoint table;
# categories that are not keys of the table become NaN.
if "Y_hhinc_cat_12m" in df_analysis_sample:
    income_cat_12m = df_analysis_sample["Y_hhinc_cat_12m"]
    df_analysis_sample["Y_income_num_12m"] = income_cat_12m.map(income_map_12m)
    summarize_missingness(df_analysis_sample, ["Y_income_num_12m"], "numeric 12m income")

# Catastrophic expenditure: OOP spending above a share of income (30% by default),
# with explicit handling of zero / missing income.
def calc_catastrophic(row, threshold: float = 0.30):
    """Flag catastrophic out-of-pocket (OOP) medical spending for one row.

    Parameters
    ----------
    row : mapping-like (e.g. a pandas Series from ``DataFrame.apply(axis=1)``)
        Read with ``row.get``: ``"Y_income_num_12m"`` (numeric income) and
        ``"Y_cost_tot_oop_12m"`` (total OOP). Absent keys count as NaN.
    threshold : float, default 0.30
        Share of income that OOP spending must strictly exceed to count as
        catastrophic.

    Returns
    -------
    int or float
        1 if catastrophic, 0 if not, NaN when income or OOP is missing.
        With zero income, any positive OOP spending counts as catastrophic.
    """
    income = row.get("Y_income_num_12m", np.nan)
    oop = row.get("Y_cost_tot_oop_12m", np.nan)

    if pd.isna(income) or pd.isna(oop):
        return np.nan
    if income == 0:
        # The OOP/income ratio is undefined; treat any positive spending as catastrophic.
        return 1 if oop > 0 else 0
    return 1 if (oop / income) > threshold else 0

# Row-wise catastrophic-expenditure flag (1/0, NaN where income or OOP is missing).
df_analysis_sample = df_analysis_sample.assign(
    Y_catastrophic_exp_12m=lambda frame: frame.apply(calc_catastrophic, axis=1)
)
print(
    "\nCatastrophic expenditure rate (non-missing):",
    df_analysis_sample["Y_catastrophic_exp_12m"].mean(skipna=True),
)


## 12. Baseline age construction

In [None]:
# Baseline age in years, with 2008 as the reference year.
if "X_birthyear_0m" in df_analysis_sample:
    df_analysis_sample["X_age_0m"] = df_analysis_sample["X_birthyear_0m"].rsub(2008)
    print(
        "Baseline age range:",
        df_analysis_sample["X_age_0m"].min(),
        "to",
        df_analysis_sample["X_age_0m"].max(),
    )
    # Birth year is redundant once age exists (age = 2008 - birth year), so it
    # is dropped here. To keep it (e.g. for robustness checks), remove the
    # next statement.
    df_analysis_sample = df_analysis_sample.drop(columns=["X_birthyear_0m"])


## 13. Missing-data handling and imputation for X covariates

In [None]:
# Schema for baseline covariates (after X_ prefix). The type controls the
# post-imputation treatment below: binary / count / ordinal imputations are
# rounded back to valid values; continuous imputations are capped.
col_schema = {
    "binary": [
        "X_need_med_0m", "X_needmet_med_0m", "X_need_rx_0m", "X_needmet_rx_0m",
        "X_dia_dx_0m", "X_ast_dx_0m", "X_hbp_dx_0m", "X_emp_dx_0m", "X_chf_dx_0m",
        "X_dep_dx_0m", "X_female_0m", "X_employ_0m", "X_zip_msa_list",
        "X_race_hisp_0m", "X_race_white_0m", "X_race_black_0m", "X_race_amerindian_0m",
        "X_race_asian_0m", "X_race_pacific_0m", "X_race_other_qn_0m",
        "X_cost_borrow_0m", "X_cost_any_owe_0m", "X_cost_refused_0m", "X_cost_any_oop_0m",
    ],
    "count": [
        "X_rx_num_mod_0m", "X_doc_num_mod_0m", "X_er_num_mod_0m", "X_hosp_num_mod_0m",
        "X_hhsize_0m", "X_num19_0m",
    ],
    "ordinal": [
        "X_surv_lang_0m", "X_health_gen_0m", "X_health_chg_0m",
        "X_hhinc_cat_0m", "X_edu_0m", "X_employ_hrs_0m",
        "X_baddays_phys_0m", "X_baddays_ment_0m", "X_ins_months_0m",
    ],
    "continuous": [
        "X_cost_tot_owe_0m", "X_cost_tot_oop_correct_0m",
        # Baseline age (2008 - birth year). Without this entry, age would be
        # left out of the imputed covariate set and of the saved dataset.
        "X_age_0m",
    ],
}

# Keep only schema variables that are actually present in the analysis sample
for k in col_schema:
    col_schema[k] = [c for c in col_schema[k] if c in df_analysis_sample.columns]

all_x_cols = col_schema["binary"] + col_schema["count"] + col_schema["ordinal"] + col_schema["continuous"]

# A covariate with no observed values cannot be imputed: IterativeImputer drops
# such columns, which would break the column alignment below. Exclude them.
empty_x_cols = [c for c in all_x_cols if df_analysis_sample[c].isna().all()]
if empty_x_cols:
    print(f"Excluding covariates with no observed values: {empty_x_cols}")
    all_x_cols = [c for c in all_x_cols if c not in empty_x_cols]

print(f"Total X covariates in schema: {len(all_x_cols)}")
summarize_missingness(df_analysis_sample, all_x_cols, "X (schema vars)")

# Columns that actually need imputation (have any missing values)
cols_to_impute = [c for c in all_x_cols if df_analysis_sample[c].isna().any()]
print(f"\nColumns with missingness to impute: {len(cols_to_impute)}")

# One 0/1 flag per imputed column; it also marks which entries were imputed.
missing_flag_cols = []
for col in cols_to_impute:
    flag_col = f"{col}_missing"
    df_analysis_sample[flag_col] = df_analysis_sample[col].isna().astype("int8")
    missing_flag_cols.append(flag_col)
print(f"Total missing-flag columns created: {len(missing_flag_cols)}")

# Observed (pre-imputation) ranges of the discrete covariates, used below to
# keep rounded imputations within values actually seen in the data.
discrete_x_cols = col_schema["binary"] + col_schema["count"] + col_schema["ordinal"]
orig_minmax = {
    col: (
        df_analysis_sample[col].min(skipna=True),
        df_analysis_sample[col].max(skipna=True),
    )
    for col in discrete_x_cols if col in cols_to_impute
}

# Configure tree-based imputer
ets = ExtraTreesRegressor(
    n_estimators=200,
    max_depth=None,
    min_samples_leaf=2,
    max_features="sqrt",
    random_state=RANDOM_SEED,
    n_jobs=-1,
)

imputer = IterativeImputer(
    estimator=ets,
    max_iter=2,
    tol=1e-3,
    initial_strategy="median",
    imputation_order="ascending",
    add_indicator=False,  # we created our own flags
    random_state=RANDOM_SEED,
    verbose=0,
)

# Fit on full covariate context (schema vars only; outcome-free)
clean_x_cols = [c for c in all_x_cols if c not in cols_to_impute]
cols_for_model = clean_x_cols + cols_to_impute

X_full_context = df_analysis_sample[cols_for_model].copy()
imputer.fit(X_full_context)
X_imputed_array = imputer.transform(X_full_context)

X_imputed = pd.DataFrame(
    X_imputed_array,
    columns=cols_for_model,
    index=df_analysis_sample.index,
)

# Overwrite only the columns we chose to impute
df_analysis_sample[cols_to_impute] = X_imputed[cols_to_impute]

# Discrete covariates: tree imputations are averages (e.g. 0.4 for a binary
# variable), so round imputed entries to the nearest integer and keep them in
# the observed range. Observed entries are left untouched.
# NOTE(review): assumes categorical/ordinal variables are integer-coded.
for col, (lo, hi) in orig_minmax.items():
    imputed_rows = df_analysis_sample[f"{col}_missing"] == 1
    df_analysis_sample.loc[imputed_rows, col] = (
        df_analysis_sample.loc[imputed_rows, col].round().clip(lower=lo, upper=hi)
    )

# Continuous covariates: cap imputed entries at the 1st/99th percentiles of
# the OBSERVED values. This tames extreme imputations without winsorizing the
# observed data.
for col in col_schema["continuous"]:
    if col in cols_to_impute:
        imputed_rows = df_analysis_sample[f"{col}_missing"] == 1
        observed_vals = df_analysis_sample.loc[~imputed_rows, col]
        low_cap = observed_vals.quantile(0.01)
        high_cap = observed_vals.quantile(0.99)
        df_analysis_sample.loc[imputed_rows, col] = (
            df_analysis_sample.loc[imputed_rows, col].clip(lower=low_cap, upper=high_cap)
        )
        print(f"Clipped imputed values of {col} to [{low_cap:.2f}, {high_cap:.2f}]")

# Sanity checks
remaining_nans = df_analysis_sample[all_x_cols].isna().sum().sum()
print(f"\nRemaining missing values in X schema vars: {remaining_nans}")
assert remaining_nans == 0, "Imputation failed – some X covariates still contain NA."

for col in missing_flag_cols:
    unique_vals = set(df_analysis_sample[col].dropna().unique())
    if not unique_vals.issubset({0, 1}):
        raise ValueError(f"Flag {col} has non-binary values: {unique_vals}")

print(
    f"Imputation successful: {df_analysis_sample.shape[0]} rows, "
    f"{len(all_x_cols)} covariates + {len(missing_flag_cols)} missingness flags"
)

final_covariates = all_x_cols + missing_flag_cols
len(final_covariates)


## 14. Final dataset and save to disk

In [None]:
# Columns needed for downstream causal ML / econometrics
core_cols = [
    "person_id",
    "household_id",
    # Number of household members on the lottery list. NOTE: per the OHIE
    # design papers (Finkelstein et al. 2012), the probability of selection
    # varies with this count, so downstream models should condition on it —
    # confirm against the study documentation.
    "numhh_list",
    "Z_lottery",
    "W_medicaid",
    "W_medicaid_months",
    "weight_12m",
    "weight_attrition",
]

# Ensure we only keep columns that exist
core_cols = [c for c in core_cols if c in df_analysis_sample.columns]

# The derived outcomes are created conditionally (numeric income requires
# Y_hhinc_cat_12m), so filter them like every other column to avoid a KeyError.
derived_y_cols = ["Y_income_num_12m", "Y_catastrophic_exp_12m"]
final_y_cols = [c for c in Y_cols + derived_y_cols if c in df_analysis_sample.columns]

final_x_cols = [c for c in final_covariates if c in df_analysis_sample.columns]

all_keep = core_cols + final_y_cols + final_x_cols
all_keep = list(dict.fromkeys(all_keep))  # de-duplicate while preserving order

# Row filtering left a non-contiguous index, and Feather requires a default
# RangeIndex. person_id already identifies rows, so the old index is dropped.
df_final = df_analysis_sample[all_keep].reset_index(drop=True)
print(
    f"Final prepared dataset: {df_final.shape[0]:,} rows × {df_final.shape[1]} columns"
)

out_path = DATA_DIR / "ohie_full_intermediate_dataset.feather"
df_final.to_feather(out_path)
out_path
