In [23]:
from __future__ import annotations
import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Dict, Any, Union

def _sigmoid(z):
    return 1 / (1 + np.exp(-z))

@dataclass
class CausalDatasetGenerator:
    # Core knobs
    theta: float = 1.0                            # constant treatment effect (ATE) if tau is None
    tau: Optional[Callable[[np.ndarray], np.ndarray]] = None  # heterogeneous effect tau(X) if provided

    # Confounder -> outcome/treatment effects
    beta_y: Optional[np.ndarray] = None           # shape (k,)
    beta_t: Optional[np.ndarray] = None           # shape (k,)
    g_y: Optional[Callable[[np.ndarray], np.ndarray]] = None  # nonlinear baseline outcome f_y(X)
    g_t: Optional[Callable[[np.ndarray], np.ndarray]] = None  # nonlinear treatment score f_t(X)

    # Outcome/treatment intercepts and noise
    alpha_y: float = 0.0
    alpha_t: float = 0.0
    sigma_y: float = 1.0                          # used when outcome_type="continuous"
    outcome_type: str = "continuous"              # "continuous" | "binary" | "poisson"

    # Confounder generation
    confounder_specs: Optional[List[Dict[str, Any]]] = None   # list of {"name","dist",...}
    k: int = 5                                    # used if confounder_specs is None
    x_sampler: Optional[Callable[[int, int, int], np.ndarray]] = None  # custom sampler (n, k, seed)->X

    # Practical controls
    target_t_rate: Optional[float] = None         # e.g., 0.3 -> ~30% treated; solves for alpha_t
    u_strength_t: float = 0.0                     # unobserved confounder effect on treatment
    u_strength_y: float = 0.0                     # unobserved confounder effect on outcome
    seed: Optional[int] = None

    # Internals (filled post-init)
    rng: np.random.Generator = field(init=False, repr=False)

    def __post_init__(self):
        self.rng = np.random.default_rng(self.seed)
        if self.confounder_specs is not None:
            self.k = len(self.confounder_specs)

    # ---------- Confounder sampling ----------

    def _sample_X(self, n: int) -> (np.ndarray, List[str]):
        if self.x_sampler is not None:
            X = self.x_sampler(n, self.k, self.seed)
            names = [f"x{i+1}" for i in range(self.k)]
            return X, names

        if self.confounder_specs is None:
            # Default: independent standard normals
            X = self.rng.normal(size=(n, self.k))
            names = [f"x{i+1}" for i in range(self.k)]
            return X, names

        cols = []
        names = []
        for spec in self.confounder_specs:
            name = spec.get("name") or f"x{len(names)+1}"
            dist = spec.get("dist", "normal").lower()
            if dist == "normal":
                mu = spec.get("mu", 0.0); sd = spec.get("sd", 1.0)
                col = self.rng.normal(mu, sd, size=n)
            elif dist == "uniform":
                a = spec.get("a", 0.0); b = spec.get("b", 1.0)
                col = self.rng.uniform(a, b, size=n)
            elif dist == "bernoulli":
                p = spec.get("p", 0.5)
                col = self.rng.binomial(1, p, size=n).astype(float)
            elif dist == "categorical":
                categories = spec.get("categories", [0,1,2])
                probs = spec.get("probs", None)
                col = self.rng.choice(categories, p=probs, size=n)
                # one-hot encode (except first level)
                oh = [ (col == c).astype(float) for c in categories[1:] ]
                if not oh:
                    oh = [np.zeros(n)]
                for j, c in enumerate(categories[1:]):
                    cols.append(oh[j])
                    names.append(f"{name}_{c}")
                continue
            else:
                raise ValueError(f"Unknown dist: {dist}")
            cols.append(col.astype(float))
            names.append(name)
        X = np.column_stack(cols) if cols else np.empty((n,0))
        self.k = X.shape[1]
        return X, names

    # ---------- Helpers ----------

    def _treatment_score(self, X: np.ndarray, U: np.ndarray) -> np.ndarray:
        lin = np.zeros(X.shape[0])
        if self.beta_t is not None:
            lin += X @ self.beta_t
        if self.g_t is not None:
            lin += self.g_t(X)
        if self.u_strength_t != 0:
            lin += self.u_strength_t * U
        return lin

    def _outcome_location(self, X: np.ndarray, T: np.ndarray, U: np.ndarray, tau_x: np.ndarray) -> np.ndarray:
        # location on natural scale for continuous; on logit/log scale for binary/poisson
        loc = self.alpha_y
        if self.beta_y is not None:
            loc += X @ self.beta_y
        if self.g_y is not None:
            loc += self.g_y(X)
        if self.u_strength_y != 0:
            loc += self.u_strength_y * U
        loc += T * tau_x
        return loc

    def _calibrate_alpha_t(self, X: np.ndarray, U: np.ndarray, target: float) -> float:
        # Bisection on alpha_t so that mean propensity ~ outcome
        lo, hi = -15.0, 15.0
        for _ in range(60):
            mid = 0.5*(lo+hi)
            p = _sigmoid(mid + self._treatment_score(X, U))
            m = p.mean()
            if m > target:
                hi = mid
            else:
                lo = mid
        return 0.5*(lo+hi)

    # ---------- Public API ----------

    def generate(self, n: int) -> pd.DataFrame:
        X, names = self._sample_X(n)
        U = self.rng.normal(size=n)  # unobserved confounder

        # Treatment assignment
        if self.target_t_rate is not None:
            self.alpha_t = self._calibrate_alpha_t(X, U, self.target_t_rate)
        logits_t = self.alpha_t + self._treatment_score(X, U)
        propensity = _sigmoid(logits_t)
        T = self.rng.binomial(1, propensity).astype(float)

        # Treatment effect (constant or heterogeneous)
        tau_x = (self.tau(X) if self.tau is not None else np.full(n, self.theta)).astype(float)

        # Outcome generation
        loc = self._outcome_location(X, T, U, tau_x)

        if self.outcome_type == "continuous":
            Y = loc + self.rng.normal(0, self.sigma_y, size=n)
            mu0 = self._outcome_location(X, np.zeros(n), U, np.zeros(n))
            mu1 = self._outcome_location(X, np.ones(n),  U, tau_x)
        elif self.outcome_type == "binary":
            # logit: logit P(Y=1|T,X) = loc
            p = _sigmoid(loc)
            Y = self.rng.binomial(1, p).astype(float)
            mu0 = _sigmoid(self._outcome_location(X, np.zeros(n), U, np.zeros(n)))
            mu1 = _sigmoid(self._outcome_location(X, np.ones(n),  U, tau_x))
        elif self.outcome_type == "poisson":
            # log link: log E[Y|T,X] = loc
            lam = np.exp(loc)
            Y = self.rng.poisson(lam).astype(float)
            mu0 = np.exp(self._outcome_location(X, np.zeros(n), U, np.zeros(n)))
            mu1 = np.exp(self._outcome_location(X, np.ones(n),  U, tau_x))
        else:
            raise ValueError("outcome_type must be 'continuous', 'binary', or 'poisson'")

        df = pd.DataFrame({"y": Y, "t": T})
        for j, name in enumerate(names):
            df[name] = X[:, j]
        # Useful ground-truth columns for evaluation
        df["propensity"] = propensity
        df["mu0"] = mu0
        df["mu1"] = mu1
        df["cate"] = mu1 - mu0
        return df


In [24]:
gen = CausalDatasetGenerator(
    theta=2.0,                       # constant ATE of +2
    beta_y=np.array([1.0, -0.5, 0.2]),
    beta_t=np.array([0.2,  0.4, -0.3]),
    alpha_y=0.0,
    target_t_rate=0.35,              # ~35% treated on average
    sigma_y=1.0,
    outcome_type="continuous",
    seed=42,
    confounder_specs=[
        {"name":"age", "dist":"normal", "mu":50, "sd":10},
        {"name":"smoker", "dist":"bernoulli", "p":0.3},
        {"name":"bmi", "dist":"normal", "mu":27, "sd":4},
    ],
)
df = gen.generate(10_000)
# df has: y, t, age, smoker, bmi, propensity, mu0, mu1, cate


In [25]:
df.head()

Unnamed: 0,y,t,age,smoker,bmi,propensity,mu0,mu1,cate
0,58.057239,0.0,53.047171,1.0,20.651225,0.846899,56.677416,58.677416,2.0
1,47.379975,0.0,39.600159,0.0,32.502713,0.007144,46.100702,48.100702,2.0
2,65.457219,1.0,57.504512,1.0,32.180623,0.297994,63.440637,65.440637,2.0
3,67.484632,0.0,59.405647,0.0,35.945336,0.118569,66.594714,68.594714,2.0
4,35.68131,0.0,30.489648,0.0,31.046206,0.001798,36.698889,38.698889,2.0


In [26]:
import numpy as np

import doubleml as dml

from doubleml.datasets import make_irm_data

from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier

ml_g = RandomForestRegressor(n_estimators=100, max_features=10, max_depth=5, min_samples_leaf=2)

ml_m = RandomForestClassifier(n_estimators=100, max_features=10, max_depth=5, min_samples_leaf=2)

np.random.seed(3333)


# Initialize DoubleMLData (data-backend of DoubleML)
data_dml = dml.DoubleMLData(df,
                                 y_col='y',
                                 d_cols='t',
                                 x_cols=['age', 'smoker', 'bmi'])


dml_irm_obj = dml.DoubleMLIRM(data_dml, ml_g, ml_m)

print(dml_irm_obj.fit())




------------------ Data Summary      ------------------
Outcome variable: y
Treatment variable(s): ['t']
Covariates: ['age', 'smoker', 'bmi']
Instrument variable(s): None
No. Observations: 10000


------------------ Score & Algorithm ------------------
Score function: ATE

------------------ Machine Learner   ------------------
Learner ml_g: RandomForestRegressor(max_depth=5, max_features=10, min_samples_leaf=2)
Learner ml_m: RandomForestClassifier(max_depth=5, max_features=10, min_samples_leaf=2)
Out-of-sample Performance:
Regression:
Learner ml_g0 RMSE: [[1.19899543]]
Learner ml_g1 RMSE: [[1.18290354]]
Classification:
Learner ml_m Log Loss: [[0.40454656]]

------------------ Resampling        ------------------
No. folds: 5
No. repeated sample splits: 1

------------------ Fit Summary       ------------------
       coef  std err          t          P>|t|     2.5 %   97.5 %
t  2.173678  0.08038  27.042559  4.672269e-161  2.016136  2.33122


In [27]:
(df['propensity'].between(0.05, 0.95)).mean()

np.float64(0.7383)