# Chapter 9: Bridging Finite and Super-population Causal Inference

In [28]:
from joblib import Parallel, delayed

import numpy as np
import statsmodels.api as sm

# Seed numpy's global (legacy) RNG for this kernel.
# NOTE(review): this only affects the parent process; joblib's process-based
# workers used in the simulation cell presumably do not inherit this state, so
# reproducibility there requires passing explicit per-replication seeds — verify.
np.random.seed(42)
# Reload edited modules automatically (mode 1: only modules registered with %aimport).
%load_ext autoreload
%autoreload 1

# Record package versions used to produce the outputs below.
%load_ext watermark
%watermark --iversions



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark
numpy            : 1.23.5
pandas           : 2.1.1
matplotlib_inline: 0.1.6
statsmodels      : 0.13.5
matplotlib       : 3.8.0



In [29]:
def linestimator(Z, Y, X):
    """Lin's fully interacted OLS estimate of the ATE with EHW and super-population SEs.

    Regresses Y on [1, Z, Xc, Z * Xc], where Xc is X with each column
    standardized to mean 0 and unit (population) SD. Because Xc is centered,
    the coefficient on Z estimates the average treatment effect.

    Parameters
    ----------
    Z : array-like, shape (n,)
        Treatment indicator (0/1).
    Y : array-like, shape (n,)
        Observed outcome.
    X : array-like, shape (n, p) or (n,)
        Covariates; a 1-D array is treated as a single covariate column.

    Returns
    -------
    est : float
        Coefficient on Z (the ATE estimate).
    se_ehw : float
        HC2 heteroskedasticity-robust standard error of ``est``.
    se_super : float
        Standard error including the super-population correction
        (beta_1 - beta_0)' Sigma (beta_1 - beta_0) / n, where beta_1 - beta_0
        are the treatment-by-covariate interaction coefficients and Sigma is
        the sample covariance of the standardized covariates.

    Raises
    ------
    ValueError
        If any covariate column is constant (it cannot be standardized).
    """
    # Plain ndarrays so statsmodels returns ndarray params (positional indexing
    # below would be ambiguous on a pandas Series).
    Z = np.asarray(Z, dtype=float)
    Y = np.asarray(Y, dtype=float)
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X[:, None]

    scale = X.std(axis=0)
    if np.any(scale == 0):
        raise ValueError("each covariate column must be non-constant")
    X = (X - X.mean(axis=0)) / scale
    n, p = X.shape

    # Fully interacted design [1, Z, X, Z*X]. The intercept is added explicitly:
    # sm.add_constant skips it when its input already looks constant (e.g. all Z == 1).
    design = np.column_stack([np.ones(n), Z, X, Z[:, None] * X])
    fit = sm.OLS(Y, design).fit(cov_type="HC2")
    est, vehw = fit.params[1], fit.bse[1] ** 2

    # Interaction coefficients (beta_1 - beta_0) occupy columns 2+p .. 2+2p-1;
    # slicing from the front stays correct even when p == 0.
    inter = fit.params[2 + p:]
    if p > 0:
        # atleast_2d: np.cov squeezes a single-covariate result to a 0-d array,
        # which `@` rejects.
        sigma = np.atleast_2d(np.cov(X, rowvar=False))
        super_corr = inter @ sigma @ inter / n
    else:
        super_corr = 0.0
    vsuper = vehw + super_corr
    return est, np.sqrt(vehw), np.sqrt(vsuper)

In [30]:
def onerepl(seed=None, *_ignored, n=500):
    """Simulate one experiment and return ``linestimator(Z, Y, X)``.

    Data-generating process:
      X  : n x 2 iid standard normal covariates
      Y0 = X[:, 0] + X[:, 0]**2 + U(-0.5, 0.5)
      Y1 = X[:, 1] + X[:, 1]**2 + U(-1, 1)
      Z  ~ iid Bernoulli(0.6)
    Since E[Y1] = E[Y0] = 1, the true super-population ATE is 0.

    Parameters
    ----------
    seed : int, numpy SeedSequence or None
        Seed for a fresh ``np.random.default_rng``. ``None`` uses numpy's global
        legacy RNG, which preserves the behaviour of a plain ``onerepl()`` call.
        Passing the replication index makes parallel runs reproducible: joblib
        worker processes do not inherit the parent's ``np.random.seed``.
    *_ignored
        Extra positional arguments are accepted and ignored (backward compatibility).
    n : int
        Sample size per replication.

    Returns
    -------
    tuple of (estimate, EHW SE, super-population SE)
    """
    # The module-level np.random functions and a Generator share the
    # normal/uniform/binomial call signatures used below.
    rng = np.random if seed is None else np.random.default_rng(seed)
    X = rng.normal(0, 1, n * 2).reshape(n, 2)
    Y0 = X[:, 0] + X[:, 0] ** 2 + rng.uniform(-0.5, 0.5, n)
    Y1 = X[:, 1] + X[:, 1] ** 2 + rng.uniform(-1, 1, n)
    Z = rng.binomial(1, 0.6, n)
    Y = Y0 * (1 - Z) + Y1 * Z
    return linestimator(Z, Y, X)

In [31]:
# Sanity check: a single replication -> (estimate, EHW SE, super-population SE)
onerepl()

(0.052230404017171474, 0.1475302340448403, 0.1633386978782156)

In [32]:
# Monte Carlo: nrep replications spread over k worker processes. Each call
# receives its replication index as its only positional argument.
nrep, k = 2000, 8
tasks = (delayed(onerepl)(rep) for rep in range(nrep))
results = Parallel(n_jobs=k)(tasks)
# Stack into an (nrep, 3) array: columns are (estimate, EHW SE, super-population SE).
simres = np.vstack(results)

In [33]:
# Column means: estimate, EHW SE, super-population SE. The true ATE is 0 in
# this design, so the mean estimate is the Monte Carlo bias.
tuple(simres[:, col].mean() for col in range(3))

(0.002113041527341019, 0.13520566776990306, 0.15005198871562947)

In [34]:
# Empirical (Monte Carlo) SD of the point estimate, ddof=0. This is the
# benchmark the average estimated SEs above should match.
np.std(simres[:, 0])

0.1507286832895897

In [35]:
# EHW coverage: share of 95% Wald intervals est ± 1.96·se_EHW that contain the
# true ATE of 0, i.e. |est| <= 1.96·se_EHW.
np.mean(np.abs(simres[:, 0]) <= 1.96 * simres[:, 1])

0.917

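The EHW standard error under-covers the super-population ATE: coverage is 91.7% against a nominal 95%. Its average (≈0.135) is smaller than the empirical SD of the estimator (≈0.151), because it ignores the extra variability from sampling the covariates.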

In [36]:
# Super-population coverage: share of 95% Wald intervals est ± 1.96·se_super
# that contain the true ATE of 0, i.e. |est| <= 1.96·se_super.
np.mean(np.abs(simres[:, 0]) <= 1.96 * simres[:, 2])

0.9505

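With the super-population correction, coverage is essentially nominal (95.05% vs. 95%). The average corrected SE (≈0.150) closely matches the empirical SD (≈0.151).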