In [10]:
import numpy as np
from typing import List, Dict, Any

# adjust import to your package/module path
from causalkit.data import CausalDatasetGenerator

# Observed confounders: one spec per column, giving the column name, the
# distribution it is drawn from, and that distribution's parameters.
confounder_specs: List[Dict[str, Any]] = [
    dict(name="tenure_months", dist="normal", mu=24, sd=12),
    dict(name="avg_sessions_week", dist="normal", mu=5, sd=2),
    dict(name="spend_last_month", dist="uniform", a=0, b=200),
    dict(name="premium_user", dist="bernoulli", p=0.25),
    dict(name="urban_resident", dist="bernoulli", p=0.60),
]

# Outcome coefficients, keyed by column name (linear, well-specified).
# Every confounder pushes Y upward: tenure, sessions, spend, premium, urban.
beta_y_map = dict(
    tenure_months=0.05,      # 0.05 * 12 months (+1 SD of tenure) = 0.6 shift in Y
    avg_sessions_week=0.60,  # strong engagement signal
    spend_last_month=0.005,  # spend spans 0..200, so it contributes up to ~1
    premium_user=0.80,
    urban_resident=0.20,
)

# Treatment-score coefficients, keyed by column name: moderate dependence on
# tenure, engagement, spend, premium status and urban residency.
beta_d_map = dict(
    tenure_months=0.08,
    avg_sessions_week=0.12,
    spend_last_month=0.004,
    premium_user=0.25,
    urban_resident=0.10,
)

def expand_beta_from_specs(specs: List[Dict[str, Any]], beta_map: Dict[str, float]) -> np.ndarray:
    """Build a coefficient vector with one entry per spec, in spec order.

    Parameters
    ----------
    specs : list of dict
        Confounder specifications. Each may carry a ``"name"`` (default ``""``)
        and a ``"dist"`` (default ``"normal"``, compared case-insensitively).
    beta_map : dict
        Coefficient per confounder name. Names missing from the map get 0.0;
        map keys that match no spec are ignored.

    Returns
    -------
    np.ndarray
        Float array of length ``len(specs)``.

    Raises
    ------
    ValueError
        If a spec's distribution is not normal, uniform or bernoulli.
    """
    supported = ("normal", "uniform", "bernoulli")
    coefs = []
    for entry in specs:
        column = entry.get("name", "")
        kind = str(entry.get("dist", "normal")).lower()
        # Guard first: only these distributions are handled here, each
        # receiving exactly one coefficient.
        if kind not in supported:
            raise ValueError(f"Unsupported dist in this simple setup: {kind}")
        coefs.append(beta_map.get(column, 0.0))
    return np.asarray(coefs, dtype=float)

# Turn the name-keyed coefficient maps into vectors ordered like confounder_specs.
beta_y = expand_beta_from_specs(confounder_specs, beta_y_map)
beta_d = expand_beta_from_specs(confounder_specs, beta_d_map)

gen = CausalDatasetGenerator(
    theta=0.80,                 # constant treatment effect
    tau=None,                   # use theta
    beta_y=beta_y,
    beta_d=beta_d,
    g_y=None, g_d=None,         # no nonlinearities
    alpha_y=0.0,
    alpha_d=0.0,
    sigma_y=1.0,
    outcome_type="continuous",  # Gaussian Y
    confounder_specs=confounder_specs,
    # IRM-EASY: no unobserved confounding, standard sharpness
    u_strength_d=0.0,
    u_strength_y=0.0,
    propensity_sharpness=1.0,
    # target an overall treatment share of ~0.20 even with imbalanced features
    target_d_rate=0.20,
    seed=123
)

n = 100_000
df = gen.generate(n)

print("Treatment share ≈", df["d"].mean())        # ~0.20 (target_d_rate)
print(df.filter(regex=r"^(g0|g1|cate)$").head())  # cate == 0.8 everywhere
# Columns include: y, d, tenure_months, avg_sessions_week, spend_last_month,
#                  premium_user, urban_resident, m, g0, g1, cate, ...


Treatment share ≈ 0.19986
         g0        g1  cate
0  4.439569  5.239569   0.8
1  2.914083  3.714083   0.8
2  3.190812  3.990812   0.8
3  2.681789  3.481789   0.8
4  3.813471  4.613471   0.8


In [11]:
from causalkit.data import CausalData
# Ground-truth ATT (on the natural scale): E[tau(X) | D=1], i.e. the mean CATE
# among treated rows (d == 1).
true_att = float(df.loc[df["d"] == 1, "cate"].mean())
print(f"Ground-truth ATT from the DGP: {true_att:.3f}")

# Wrap as CausalData for downstream workflows (keeps only y, d, and the
# specified confounders). Each confounder is listed exactly once.
causal_data = CausalData(
    df=df,
    treatment="d",
    outcome="y",
    confounders=[
        "tenure_months",
        "avg_sessions_week",
        "spend_last_month",
        "premium_user",
        "urban_resident",
    ],
)

# Peek at the analysis-ready view
causal_data.df.head()

Ground-truth ATT from the DGP: 0.800


Unnamed: 0,y,d,tenure_months,avg_sessions_week,spend_last_month,premium_user,urban_resident
0,4.431598,0.0,12.130544,5.803342,30.207326,0.0,1.0
1,3.089098,0.0,19.58656,2.320295,68.515624,0.0,1.0
2,3.320877,0.0,39.455103,1.351509,41.43027,0.0,1.0
3,2.212806,0.0,26.327693,1.459923,97.89003,0.0,0.0
4,3.919363,0.0,35.042771,1.938168,139.686209,0.0,1.0


In [12]:
from causalkit.inference.ate import dml_ate

# Estimate the Average Treatment Effect (ATE) with dml_ate.
# NOTE(review): the result is stored as `att_result` even though this is the
# ATE call; the name is left unchanged in case later cells reference it.
att_result = dml_ate(causal_data, n_folds=4)
att_result

{'coefficient': 0.8002447624692177,
 'std_error': 0.012660407668213813,
 'p_value': 0.0,
 'confidence_interval': (0.7754308194099239, 0.8250587055285116),
 'model': <causalkit.inference.estimators.irm.IRM at 0x158ae8f50>,
 'diagnostic_data': {'m_hat': array([0.03536389, 0.06605769, 0.23265153, ..., 0.33150187, 0.07472227,
         0.05598924], shape=(100000,)),
  'g0_hat': array([4.4493435 , 2.93265171, 3.25461208, ..., 6.17353241, 4.13386574,
         3.4949742 ], shape=(100000,)),
  'g1_hat': array([4.8247974 , 3.40419783, 4.13266757, ..., 6.68172359, 5.1071098 ,
         3.84294347], shape=(100000,)),
  'y': array([4.43159809, 3.08909757, 3.32087743, ..., 7.70334644, 5.56867371,
         2.46455984], shape=(100000,)),
  'd': array([0, 0, 0, ..., 1, 1, 0], shape=(100000,)),
  'x': array([[ 12.1305438 ,   5.80334174,  30.20732587,   0.        ,
            1.        ],
         [ 19.58656018,   2.32029504,  68.51562396,   0.        ,
            1.        ],
         [ 39.45510314,   

In [13]:
from causalkit.inference.atte import dml_atte

# Estimate the Average Treatment effect on the Treated (ATT / ATTE) with dml_atte.
atte_result = dml_atte(causal_data, n_folds=4)
atte_result

{'coefficient': 0.7983803967140753,
 'std_error': 0.00990249789092322,
 'p_value': 0.0,
 'confidence_interval': (0.7789718574908819, 0.8177889359372688),
 'model': <causalkit.inference.estimators.irm.IRM at 0x158ae91d0>,
 'diagnostic_data': {'m_hat': array([0.05857877, 0.08156973, 0.22005608, ..., 0.35261222, 0.07122546,
         0.0571093 ], shape=(100000,)),
  'g0_hat': array([4.70544298, 2.99663398, 3.20418668, ..., 6.09139723, 4.27748497,
         3.33847651], shape=(100000,)),
  'g1_hat': array([4.74376967, 3.56687562, 4.02049963, ..., 6.68450304, 5.15821861,
         4.27247121], shape=(100000,)),
  'y': array([4.43159809, 3.08909757, 3.32087743, ..., 7.70334644, 5.56867371,
         2.46455984], shape=(100000,)),
  'd': array([0, 0, 0, ..., 1, 1, 0], shape=(100000,)),
  'x': array([[ 12.1305438 ,   5.80334174,  30.20732587,   0.        ,
            1.        ],
         [ 19.58656018,   2.32029504,  68.51562396,   0.        ,
            1.        ],
         [ 39.45510314,   1