In [5]:
from econml.dml import CausalForestDML, LinearDML
from econml.dr import DRLearner
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import numpy as np


def estimate_ate_econml(df, outcome, treatment, covariates,
                        methods=("cf", "dr", "linear_dml", "ipw"), random_state=42):
    """Estimate the average treatment effect (ATE) with several estimators.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing the outcome, treatment and covariate columns.
    outcome : str
        Name of the outcome column.
    treatment : str
        Name of the treatment column. The manual IPW branch assumes it is
        coded 0/1 (it uses ``T`` and ``1 - T`` as group indicators).
    covariates : list of str
        Names of the columns used as features/confounders ``X``.
    methods : iterable of str, optional
        Which estimators to run. Recognised keys are ``"cf"`` (causal forest
        DML), ``"linear_dml"``, ``"dr"`` (DR learner) and ``"ipw"`` (manual
        inverse-propensity weighting). Unrecognised keys are ignored.
    random_state : int or None, optional
        Seed passed to the EconML estimators and to the random-forest
        nuisance models so repeated runs give the same numbers.

    Returns
    -------
    dict
        Maps a display name (``"Causal Forest"``, ``"Linear DML"``,
        ``"DR Learner"``, ``"IPW"``) to the estimated ATE for each
        requested method.
    """
    X = df[covariates].values
    T = df[treatment].values
    Y = df[outcome].values

    results = {}

    # 1. Causal Forest DML
    if "cf" in methods:
        model_cf = CausalForestDML(
            model_y=RandomForestRegressor(random_state=random_state),
            model_t=LogisticRegression(),
            discrete_treatment=True,
            random_state=random_state,
        )
        model_cf.fit(Y, T, X=X)
        results["Causal Forest"] = model_cf.ate(X)

    # 2. Linear DML (double machine learning with a linear final stage)
    if "linear_dml" in methods:
        model_linear = LinearDML(
            model_y=RandomForestRegressor(random_state=random_state),
            model_t=LogisticRegression(),
            discrete_treatment=True,
            random_state=random_state,
        )
        model_linear.fit(Y, T, X=X)
        results["Linear DML"] = model_linear.ate(X)

    # 3. DR Learner
    if "dr" in methods:
        # DRLearner only supports discrete treatments and its constructor has
        # no `discrete_treatment` keyword, so that argument is not passed.
        model_dr = DRLearner(
            model_regression=RandomForestRegressor(random_state=random_state),
            model_propensity=LogisticRegression(),
            random_state=random_state,
        )
        model_dr.fit(Y, T, X=X)
        results["DR Learner"] = model_dr.ate(X)

    # 4. Manual IPW estimation
    if "ipw" in methods:
        model_prop = LogisticRegression()
        model_prop.fit(X, T)
        p = model_prop.predict_proba(X)[:, 1]
        p = np.clip(p, 1e-3, 1 - 1e-3)  # avoid division by 0
        # Inverse-propensity-weighted ATE: each term is weighted by the
        # inverse propensity exactly once. Multiplying by an extra
        # T/p + (1-T)/(1-p) factor would square the weights and bias the
        # estimate.
        ipw_ate = np.mean(T * Y / p - (1 - T) * Y / (1 - p))
        results["IPW"] = ipw_ate

    return results


In [6]:
# Run the estimators with their default method list.
# `osrct_df` and `covariates` must already exist in the kernel namespace; the
# recorded output of this cell shows a NameError for `osrct_df`, so the cell
# that builds it has to be executed first.
results = estimate_ate_econml(
    osrct_df,
    outcome="Diabetes_012",
    treatment="HighBP",
    covariates=covariates,
)

# Report one line per estimator, formatted to four decimal places.
print("\n🎯 ATE Estimates (EconML):")
for method in results:
    val = results[method]
    print(f"{method}: {val:.4f}")


NameError: name 'osrct_df' is not defined