In [1]:
import numpy as np
import sklearn
import sklearn.pipeline as skpipe
# learners
import celer as cel
from lightgbm import LGBMClassifier, LGBMRegressor

# this module
from aipyw import AIPyW
from aipyw.dgp import dgp_binary, dgp_discrete, hainmueller

np.random.seed(42)

# Basic Demo


## Discrete Treatments

In [2]:
Y, D, X = dgp_discrete(
    n=100_000,
    p=4,
    treat_effects=np.array([0.0, 0.4, 0.5, 0.55]),
)
Y.shape, D.shape, X.shape

((100000,), (100000,), (100000, 10))

In [3]:
(Y[D == 1,].mean() - Y[D == 0,].mean(),
  Y[D == 2,].mean() - Y[D == 0,].mean(),
  Y[D == 3,].mean() - Y[D == 0,].mean()
)

(np.float64(-2.459176579755844),
 np.float64(0.9312130397763212),
 np.float64(2.8166707333716685))

Naive estimates badly biased.

In [4]:
%%time
doubledouble3 = AIPyW(riesz_method="ipw")
doubledouble3.fit(X, D, Y)
doubledouble3.summary()

CPU times: user 7.6 s, sys: 42.6 ms, total: 7.65 s
Wall time: 570 ms


{'1 vs 0': {'effect': np.float64(0.3993327187181337),
  'se': np.float64(1.1326191260309973e-05)},
 '2 vs 0': {'effect': np.float64(0.5002251443946422),
  'se': np.float64(1.2632634887430513e-05)},
 '3 vs 0': {'effect': np.float64(0.5520125191219095),
  'se': np.float64(1.316455977270563e-05)},
 '2 vs 1': {'effect': np.float64(0.10089242567650843),
  'se': np.float64(1.271426174924722e-05)},
 '3 vs 1': {'effect': np.float64(0.1526798004037758),
  'se': np.float64(1.1282258005654327e-05)},
 '3 vs 2': {'effect': np.float64(0.051787374727267396),
  'se': np.float64(1.3206565735523708e-05)}}

In [5]:
%%time
doubledouble3 = AIPyW(riesz_method="linear")
doubledouble3.fit(X, D, Y)
doubledouble3.summary()

CPU times: user 1min 2s, sys: 212 ms, total: 1min 2s
Wall time: 4.06 s


{'1 vs 0': {'effect': np.float64(0.3440241782297518),
  'se': np.float64(0.00031759006244523954)},
 '2 vs 0': {'effect': np.float64(0.43320743886748597),
  'se': np.float64(0.00032485666524239936)},
 '3 vs 0': {'effect': np.float64(0.47547790412244395),
  'se': np.float64(0.0003325522904819936)},
 '2 vs 1': {'effect': np.float64(0.08918326063773419),
  'se': np.float64(0.00024092384230735961)},
 '3 vs 1': {'effect': np.float64(0.13145372589269216),
  'se': np.float64(0.0002638624008049813)},
 '3 vs 2': {'effect': np.float64(0.042270465254957965),
  'se': np.float64(0.0002683575335978814)}}

In [6]:
%%time
doubledouble3 = AIPyW(riesz_method="balancing", bal_obj="quadratic")
doubledouble3.fit(X, D, Y)
doubledouble3.summary()

CPU times: user 11.2 s, sys: 15.7 ms, total: 11.2 s
Wall time: 792 ms


{'1 vs 0': {'effect': np.float64(0.3992624833348587),
  'se': np.float64(1.1131571660979226e-05)},
 '2 vs 0': {'effect': np.float64(0.4999967519495474),
  'se': np.float64(1.2429734209405745e-05)},
 '3 vs 0': {'effect': np.float64(0.55194785160569),
  'se': np.float64(1.3086491593015786e-05)},
 '2 vs 1': {'effect': np.float64(0.10073426861468866),
  'se': np.float64(1.26949544159626e-05)},
 '3 vs 1': {'effect': np.float64(0.15268536827083137),
  'se': np.float64(1.13142418777214e-05)},
 '3 vs 2': {'effect': np.float64(0.05195109965614269),
  'se': np.float64(1.3205131515576768e-05)}}

## Hainmueller (2012) Simulation study

Binary treatment, continuous outcome, 2 groups. We parametrize degree of overlap, functional form of outcome and treatment models. True effect is zero, so RMSE is easy to calculate.

In [7]:
def one_rep(
    n_samples, overlap_design, pscore_design, outcome_design, riesz_method, **kwargs
):
    # generate data
    y, d, X = hainmueller(
        n_samples=n_samples,
        overlap_design=overlap_design,
        pscore_design=pscore_design,
        outcome_design=outcome_design,
    )
    m1, m2 = LGBMRegressor(verbose=-1, n_jobs=1), LGBMClassifier(verbose=-1, n_jobs=1)
    # model instantiation
    aipw = AIPyW(
        propensity_model=m2, outcome_model=m1, riesz_method=riesz_method, **kwargs
    )
    aipw.fit(X, d, y, n_rff=100)
    return aipw.summary()["1 vs 0"]["effect"]

Favorable case: good overlap, linear pscore and outcome

In [8]:
%%time
one_rep(10_000, 2, 1, 1, "ipw")

CPU times: user 1.59 s, sys: 12 ms, total: 1.6 s
Wall time: 495 ms


np.float64(0.02722229312316711)

In [9]:
%%time
one_rep(10_000, 2, 1, 1, "linear")

CPU times: user 594 ms, sys: 0 ns, total: 594 ms
Wall time: 327 ms


np.float64(0.006689004288240835)

In [10]:
%%time
one_rep(10_000, 2, 1, 1, "kernel")

CPU times: user 23.4 s, sys: 32.1 ms, total: 23.5 s
Wall time: 1.73 s


np.float64(-0.0032402673454760135)

In [11]:
%%time
one_rep(10_000, 2, 1, 1, "automatic")

CPU times: user 1.82 s, sys: 11.9 ms, total: 1.83 s
Wall time: 364 ms


np.float64(0.02807227119629976)

In [12]:
%%time
one_rep(10_000, 2, 1, 1, "balancing")

CPU times: user 279 ms, sys: 7.97 ms, total: 287 ms
Wall time: 286 ms


np.float64(0.028048723082933033)

In [13]:
%%time
one_rep(10_000, 2, 1, 1, "balancing", bal_obj="entropy")

CPU times: user 307 ms, sys: 0 ns, total: 307 ms
Wall time: 307 ms


np.float64(-0.014473255833387634)

### Hard case: poor overlap, non-linear pscore and outcome

In [14]:
%%time
one_rep(10_000, 1, 3, 3, "ipw")

CPU times: user 409 ms, sys: 8 ms, total: 417 ms
Wall time: 417 ms


np.float64(-0.8496098275218612)

In [15]:
%%time
one_rep(10_000, 1, 3, 3, "linear")

CPU times: user 536 ms, sys: 3.86 ms, total: 540 ms
Wall time: 318 ms


np.float64(0.8783186745465631)

In [16]:
%%time
one_rep(10_000, 1, 3, 3, "kernel")

CPU times: user 23 s, sys: 36 ms, total: 23 s
Wall time: 1.69 s


np.float64(-2.280857509463784)

In [17]:
%%time
one_rep(10_000, 2, 1, 1, "automatic")

CPU times: user 1.83 s, sys: 4.09 ms, total: 1.84 s
Wall time: 330 ms


np.float64(0.05333008157288761)

In [18]:
%%time
one_rep(10_000, 1, 3, 3, "balancing")

CPU times: user 279 ms, sys: 0 ns, total: 279 ms
Wall time: 278 ms


np.float64(-1.9489090669045273)

In [19]:
%%time
one_rep(10_000, 1, 3, 3, "balancing", bal_obj="entropy")

CPU times: user 307 ms, sys: 3.99 ms, total: 311 ms
Wall time: 310 ms


np.float64(0.18399876550618385)

### all together

In [20]:
from joblib import Parallel, delayed

def compute_ate_rmse_parallel(
    n_samples,
    overlap_design,
    pscore_design,
    outcome_design,
    riesz_method,
    n_replications=100,
    n_jobs=-1,
):
    ate_estimates = Parallel(n_jobs=n_jobs)(
        delayed(one_rep)(
            n_samples, overlap_design, pscore_design, outcome_design, riesz_method
        )
        for _ in range(n_replications)
    )
    # Compute RMSE
    true_ate = 0
    rmse = np.sqrt(np.mean((np.array(ate_estimates) - true_ate) ** 2))
    return rmse

In [21]:
%%time
from itertools import product
params = np.arange(1, 4)
param_list = list(product(params, params, params, ['ipw', 'linear', 'kernel', 'automatic', 'balancing']))
res_dict = {}
for param in param_list:
  key = "_".join([str(x) for x in param])
  res_dict[key] = compute_ate_rmse_parallel(10_000, *param)

CPU times: user 24 s, sys: 3 s, total: 27 s
Wall time: 16min 24s


In [30]:
import pandas as pd
res_df = pd.DataFrame(
[
  list(product(['poor', 'good', 'medium'],
               ['linear', 'quad', 'trig'],
               ['linear', 'quad', 'nl'])),
  [v for k, v in res_dict.items() if k.endswith("ipw")],
  [v for k, v in res_dict.items() if k.endswith("linear")],
  [v for k, v in res_dict.items() if k.endswith("kernel")],
  [v for k, v in res_dict.items() if k.endswith("automatic")],
  [v for k, v in res_dict.items() if k.endswith("balancing")],
],
).T.infer_objects()
res_df.columns = ["design", "ipw", "linear", "kernel", "automatic", "constrained"]
# unpack design column
res_df['overlap_design'] = res_df['design'].apply(lambda x: x[0])
res_df['pscore_design'] = res_df['design'].apply(lambda x: x[1])
res_df['outcome_design'] = res_df['design'].apply(lambda x: x[2])
res_df.drop(columns=['design'], inplace=True)

res_df

Unnamed: 0,ipw,linear,kernel,automatic,constrained,overlap_design,pscore_design,outcome_design
0,0.068541,0.072096,0.035293,0.069347,0.066208,poor,linear,linear
1,0.065166,0.067032,0.034373,0.060328,0.066015,poor,linear,quad
2,13.757359,244.357287,8.889775,13.788107,31.737481,poor,linear,nl
3,0.054535,0.058653,0.034903,0.050918,0.055132,poor,quad,linear
4,0.038028,0.03766,0.021047,0.034115,0.037029,poor,quad,quad
5,10.769637,18.661587,26.108551,16.527473,61.495745,poor,quad,nl
6,0.024235,0.013379,0.009697,0.023839,0.025364,poor,trig,linear
7,0.024825,0.013214,0.011292,0.024514,0.022038,poor,trig,quad
8,13.712767,29.897234,6.162583,257.859393,24.983811,poor,trig,nl
9,0.035194,0.023804,0.01707,0.035736,0.033273,good,linear,linear
