# Instrumental Variables (IV) with `stochtree`

## Introduction

Here we consider a causal inference problem with a binary treatment and a binary outcome where there is unobserved confounding, but an exogenous instrument is available (also binary). This problem will require a number of extensions to the basic BART model, all of which can be implemented straightforwardly as Gibbs samplers using `stochtree`. We'll go through all of the model fitting steps in quite a lot of detail here.

## Background

To be concrete, suppose we wish to measure the effect of receiving a flu vaccine on the probability of getting the flu. Individuals who opt to get a flu shot differ in many ways from those that don't, and these lifestyle differences presumably also affect their respective chances of getting the flu. Consequently, comparing the percentage of individuals who get the flu in the vaccinated and unvaccinated groups does not give a clear picture of the vaccine efficacy. 

However, a so-called encouragement design can be implemented, in which some individuals are selected at random to be given an extra incentive to get a flu shot (free clinics at the workplace or a personalized reminder, for example). Studying the impact of this randomized encouragement allows us to tease apart the impact of the vaccine from the confounding factors, at least to some extent. This exact problem has been considered several times in the literature, starting with McDonald, Hui, and Tierney (1992), with follow-on analyses by Hirano et al. (2000), Richardson and Robins (2011), and Imbens and Rubin (2015).

Our analysis here follows the Bayesian nonparametric approach described in the supplement to Hahn, Murray, and Manolopoulou (2016).

First, we load the requisite libraries.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from stochtree import (
    RNG,
    Dataset,
    Forest,
    ForestContainer,
    ForestSampler,
    Residual, 
    ForestModelConfig, 
    GlobalModelConfig,
)

### Notation

Let $V$ denote the treatment variable (as in "vaccine"). Let $Y$ denote the response variable (getting the flu), $Z$ denote the instrument (encouragement or reminder to get a flu shot), and $X$ denote an additional observable covariate (for instance, patient age).

Further, let $S$ denote the so-called *principal strata*, an exhaustive characterization of how individuals might be affected by the encouragement regarding the flu shot. Some people will get a flu shot no matter what: these are the *always takers* (a). Some people will not get the flu shot no matter what: these are the *never takers* (n). For both always takers and never takers, the randomization of the encouragement is irrelevant. Because our data set contains no always takers who skipped the vaccine and no never takers who got it, the treatment effect of the vaccine in these groups is fundamentally non-identifiable.

By contrast, we also have *compliers* (c): folks who would not have gotten the shot but for the fact that they were encouraged to do so. These are the people about whom our randomized encouragement provides some information, because their treatment status is determined by the randomized encouragement.

Lastly, we could have *defiers* (d): contrarians who were planning on getting the shot, but, upon being reminded, decided not to! For our analysis we will do the usual thing of assuming that there are no defiers. And because we are going to simulate our data, we can make sure that this assumption is true.

## The causal diagram

The causal diagram for this model can be expressed as follows. Here we are considering one confounder and moderator variable ($X$), which is the patient's age. In our data generating process (which we know because this is a simulation demonstration), higher age makes it more likely that a person is an always taker or complier and less likely that they are a never taker, which in turn has an effect on flu risk. We stipulate here that always takers are at lower risk and never takers at higher risk. Simultaneously, age has a nonmonotone direct effect on flu risk: in the simulation below, risk first increases and then decreases with age, so that individuals at intermediate ages have the highest baseline risk. In this DGP the vaccine has a multiplicative effect, reducing flu risk by a fixed proportion (one half) of baseline risk; accordingly, the treatment effect (as a difference) is nonlinear in age (for each principal stratum).

![IV_CDAG](IV_CDAG.png)

The biggest question about this graph concerns the dashed red arrow from the putative instrument $Z$ to the outcome (flu). I say "putative" because if that dashed red arrow is there, then technically $Z$ is not a valid instrument. The assumption/assertion that there is no dashed red arrow is called the "exclusion restriction". In this vignette, we will explore what sorts of inferences are possible if we remain agnostic about the presence or absence of that dashed red arrow.

## Potential outcomes

There are two relevant potential outcomes in an instrumental variables analysis, corresponding to the causal effect of the instrument on the treatment and the causal effect of the treatment on the outcome. In this example, that is the effect of the reminder/encouragement on vaccine status and the effect of the vaccine itself on the flu. The notation is $V(Z)$ and $Y(V(Z),Z)$, respectively, so that we have six distinct random variables: $V(0)$, $V(1)$, $Y(0,0)$, $Y(1,0)$, $Y(0,1)$ and $Y(1,1)$. The problem, sometimes called the *fundamental problem of causal inference*, is that some of these random variables can never be observed simultaneously: they are observationally mutually exclusive. For this reason, it may be helpful to think about causal inference as a missing data problem, as depicted in the following table.

| $i$          | $Z_i$          | $V_i(0)$      | $V_i(1)$      | $Y_i(0,0)$    | $Y_i(1,0)$    | $Y_i(0,1)$    | $Y_i(1,1)$    |
|    :---:     |     :---:      |     :---:     |     :---:     |     :---:     |     :---:     |     :---:     |     :---:     |
| 1            | 1              | ?             | 1             | ?             | ?             | ?             | 0             |
| 2            | 0              | 1             | ?             | ?             | 1             | ?             | ?             |
| 3            | 0              | 0             | ?             | 1             | ?             | ?             | ?             |
| 4            | 1              | ?             | 0             | ?             | ?             | 0             | ?             |

Likewise, with this notation we can formally define the principal strata:

| $V_i(0)$      | $V_i(1)$      | $S_i$              |
|     :---:     |     :---:     |     :---:          |
| 0             | 0             | Never Taker ($n$)  |
| 1             | 1             | Always Taker ($a$) |
| 0             | 1             | Complier ($c$)     |
| 1             | 0             | Defier ($d$)       |

## Estimands and Identification

Let $\pi_s(x)$ denote the conditional (on $x$) probability that an individual belongs to principal stratum $s$:

\begin{equation}
\pi_s(x)=\operatorname{Pr}(S=s \mid X=x),
\end{equation}

and let $\gamma_s^{v z}(x)$ denote the potential outcome probability for given values $v$ and $z$:

\begin{equation}
\gamma_s^{v z}(x)=\operatorname{Pr}(Y(v, z)=1 \mid S=s, X=x)
\end{equation}

Various estimands of interest may be expressed in terms of the functions $\gamma_c^{vz}(x)$. In particular, the complier conditional average treatment effect $$\gamma_c^{1,z}(x) - \gamma_c^{0,z}(x)$$ is the ultimate goal (for either $z=0$ or $z=1$). Under an exclusion restriction, we would have $\gamma_s^{vz}(x) = \gamma_s^{v}(x)$, so the reminder status $z$ itself would not matter. In that case, we can estimate both $$\gamma_c^{1,z}(x) - \gamma_c^{0,z}(x)$$ and $$\gamma_c^{1,1}(x) - \gamma_c^{0,0}(x),$$ which then coincide. This latter quantity is called the complier intent-to-treat effect, or $ITT_c$, and it can be partially identified even if the exclusion restriction is violated, as follows.

The left-hand sides of the following system of equations are all quantities that can be estimated from observable data, while the right-hand sides involve the unknown functions of interest, $\gamma_s^{vz}(x)$:

\begin{equation}
\begin{aligned}
p_{1 \mid 00}(x) = \operatorname{Pr}(Y=1 \mid V=0, Z=0, X=x)=\frac{\pi_c(x)}{\pi_c(x)+\pi_n(x)} \gamma_c^{00}(x)+\frac{\pi_n(x)}{\pi_c(x)+\pi_n(x)} \gamma_n^{00}(x) \\
p_{1 \mid 11}(x) =\operatorname{Pr}(Y=1 \mid V=1, Z=1, X=x)=\frac{\pi_c(x)}{\pi_c(x)+\pi_a(x)} \gamma_c^{11}(x)+\frac{\pi_a(x)}{\pi_c(x)+\pi_a(x)} \gamma_a^{11}(x) \\
p_{1 \mid 01}(x) =\operatorname{Pr}(Y=1 \mid V=0, Z=1, X=x)=\frac{\pi_d(x)}{\pi_d(x)+\pi_n(x)} \gamma_d^{01}(x)+\frac{\pi_n(x)}{\pi_d(x)+\pi_n(x)} \gamma_n^{01}(x) \\
p_{1 \mid 10}(x) =\operatorname{Pr}(Y=1 \mid V=1, Z=0, X=x)=\frac{\pi_d(x)}{\pi_d(x)+\pi_a(x)} \gamma_d^{10}(x)+\frac{\pi_a(x)}{\pi_d(x)+\pi_a(x)} \gamma_a^{10}(x)
\end{aligned}
\end{equation}

Furthermore, we have

\begin{equation}
\begin{aligned}
\operatorname{Pr}(V=1 \mid Z=0, X=x)&=\pi_a(x)+\pi_d(x)\\
\operatorname{Pr}(V=1 \mid Z=1, X=x)&=\pi_a(x)+\pi_c(x)
\end{aligned}
\end{equation}

Under the monotonicity assumption, $\pi_d(x) = 0$ and these expressions simplify somewhat.

\begin{equation}
\begin{aligned}
p_{1 \mid 00}(x)&=\frac{\pi_c(x)}{\pi_c(x)+\pi_n(x)} \gamma_c^{00}(x)+\frac{\pi_n(x)}{\pi_c(x)+\pi_n(x)} \gamma_n^{00}(x) \\
p_{1 \mid 11}(x)&=\frac{\pi_c(x)}{\pi_c(x)+\pi_a(x)} \gamma_c^{11}(x)+\frac{\pi_a(x)}{\pi_c(x)+\pi_a(x)} \gamma_a^{11}(x) \\
p_{1 \mid 01}(x)&=\gamma_n^{01}(x) \\
p_{1 \mid 10}(x)&=\gamma_a^{10}(x)
\end{aligned}
\end{equation}

and

\begin{equation}
\begin{aligned}
\operatorname{Pr}(V=1 \mid Z=0, X=x)&=\pi_a(x)\\
\operatorname{Pr}(V=1 \mid Z=1, X=x)&=\pi_a(x)+\pi_c(x)
\end{aligned}
\end{equation}
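
To see the monotone identities in action, the sketch below simulates many individuals at a single covariate value and compares empirical conditional frequencies with the right-hand sides above. The stratum and outcome probabilities are made-up illustrative numbers (not the vignette's DGP), and they deliberately let the outcome depend on $z$ for always takers and never takers: these identities require only monotonicity, not the exclusion restriction. The function name `simulate_cell_probabilities` is hypothetical.

```python
import numpy as np

def simulate_cell_probabilities(pi, gamma, n=400_000, seed=1):
    """Monte Carlo check of the monotone identities at a single covariate value.
    pi:    {'a': pi_a, 'c': pi_c, 'n': pi_n} (no defiers)
    gamma: {(s, v, z): Pr(Y(v, z) = 1 | S = s)} for every (s, v, z) that can occur"""
    rng = np.random.default_rng(seed)
    s = rng.choice(np.array(["a", "c", "n"]), size=n, p=[pi["a"], pi["c"], pi["n"]])
    z = rng.binomial(1, 0.5, size=n)
    # Treatment is a deterministic function of stratum and instrument (compliers: V = Z)
    v = np.where(s == "a", 1, np.where(s == "n", 0, z))
    # Outcome probability for each simulated individual, looked up from gamma
    p_y = np.zeros(n)
    for (sk, vk, zk), g in gamma.items():
        p_y[(s == sk) & (v == vk) & (z == zk)] = g
    y = rng.binomial(1, p_y)
    empirical = {
        "p00": y[(v == 0) & (z == 0)].mean(), "p11": y[(v == 1) & (z == 1)].mean(),
        "p01": y[(v == 0) & (z == 1)].mean(), "p10": y[(v == 1) & (z == 0)].mean(),
        "PrV1_Z0": v[z == 0].mean(), "PrV1_Z1": v[z == 1].mean(),
    }
    pa, pc, pn = pi["a"], pi["c"], pi["n"]
    theory = {
        "p00": pc / (pc + pn) * gamma[("c", 0, 0)] + pn / (pc + pn) * gamma[("n", 0, 0)],
        "p11": pc / (pc + pa) * gamma[("c", 1, 1)] + pa / (pc + pa) * gamma[("a", 1, 1)],
        "p01": gamma[("n", 0, 1)], "p10": gamma[("a", 1, 0)],
        "PrV1_Z0": pa, "PrV1_Z1": pa + pc,
    }
    return empirical, theory

pi = {"a": 0.2, "c": 0.5, "n": 0.3}
gamma = {("a", 1, 0): 0.10, ("a", 1, 1): 0.08, ("n", 0, 0): 0.40,
         ("n", 0, 1): 0.35, ("c", 0, 0): 0.30, ("c", 1, 1): 0.12}
empirical, theory = simulate_cell_probabilities(pi, gamma)
for key in theory:
    print(f"{key}: empirical {empirical[key]:.3f}, formula {theory[key]:.3f}")
```

Each empirical frequency should agree with its formula up to Monte Carlo error.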

The exclusion restriction would dictate that $\gamma_s^{01}(x) = \gamma_s^{00}(x)$ and $\gamma_s^{11}(x) = \gamma_s^{10}(x)$ for all $s$. This has two implications. First, $\gamma_n^{00}(x) = \gamma_n^{01}(x) = p_{1\mid 01}(x)$ and $\gamma_a^{11}(x) = \gamma_a^{10}(x) = p_{1\mid 10}(x)$ are identified, so substituting them into the first two equations lets us solve for $\gamma_c^{11}(x)$ and $\gamma_c^{00}(x)$. Second, with these two quantities solved for, we also have the corresponding quantities at the other setting of $z$, since $\gamma_c^{10}(x) = \gamma_c^{11}(x)$ and $\gamma_c^{01}(x) = \gamma_c^{00}(x)$. Consequently, both of our estimands from above can be estimated:

$$\gamma_c^{11}(x) - \gamma_c^{01}(x)$$
and 

$$\gamma_c^{10}(x) - \gamma_c^{00}(x)$$
because they are both (supposing the exclusion restriction holds) the same as

$$\gamma_c^{11}(x) - \gamma_c^{00}(x).$$
If the exclusion restriction does *not* hold, then the three treatment effects above are all (potentially) distinct, and not much can be said about the first two. The last one, the $ITT_c$, can however be partially identified: the first two equations of our four-equation system yield non-trivial bounds because, although $\gamma_c^{11}(x)$ and $\gamma_c^{00}(x)$ are no longer point identified, they and the unidentified terms $\gamma_a^{11}(x)$ and $\gamma_n^{00}(x)$ are all probabilities and so must lie between 0 and 1. Thus,

\begin{equation}
\begin{aligned}
	\max\left(
		0, \frac{\pi_c(x)+\pi_n(x)}{\pi_c(x)}p_{1\mid 00}(x) - \frac{\pi_n(x)}{\pi_c(x)}
	\right)
&\leq\gamma^{00}_c(x)\leq
	\min\left(
		1, \frac{\pi_c(x)+\pi_n(x)}{\pi_c(x)}p_{1\mid 00}(x)
	\right)\\\\
%
\max\left(
  0, \frac{\pi_a(x)+\pi_c(x)}{\pi_c(x)}p_{1\mid 11}(x) - \frac{\pi_a(x)}{\pi_c(x)}
\right)
&\leq\gamma^{11}_c(x)\leq
\min\left(
  1, \frac{\pi_a(x)+\pi_c(x)}{\pi_c(x)}p_{1\mid 11}(x)
\right)
\end{aligned}
\end{equation}

The point of all this is that the data (plus a no-defiers assumption) let us estimate all the necessary inputs to these upper and lower bounds on $\gamma^{11}_c(x)$ and $\gamma^{00}_c(x)$, which in turn bound our estimand. What remains is to estimate those inputs, as functions of $x$, and to do so while enforcing the monotonicity restriction $$\operatorname{Pr}(V=1 \mid Z=0, X=x)=\pi_a(x) \leq 
\operatorname{Pr}(V=1 \mid Z=1, X=x)=\pi_a(x)+\pi_c(x).$$
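
As a sketch of how these formulas are used, the two functions below (hypothetical names, not `stochtree` API) compute the exclusion-restriction point values of $\gamma_c^{11}(x)$ and $\gamma_c^{00}(x)$ by substitution, and the monotonicity-only bounds. Since $ITT_c = \gamma_c^{11}(x) - \gamma_c^{00}(x)$, its lower bound is the lower bound of $\gamma_c^{11}(x)$ minus the upper bound of $\gamma_c^{00}(x)$, and vice versa. The inputs are made-up numbers at a single $x$ in which the exclusion restriction fails for always takers.

```python
import numpy as np

def complier_gammas_under_er(pi_a, pi_c, pi_n, p00, p11, p01, p10):
    """Point-identified gamma_c^{11}(x), gamma_c^{00}(x) assuming monotonicity and the
    exclusion restriction, which gives gamma_a^{11} = p10 and gamma_n^{00} = p01."""
    gamma_c11 = (pi_a + pi_c) / pi_c * p11 - pi_a / pi_c * p10
    gamma_c00 = (pi_c + pi_n) / pi_c * p00 - pi_n / pi_c * p01
    return gamma_c11, gamma_c00

def itt_c_bounds(pi_a, pi_c, pi_n, p00, p11):
    """Bounds on gamma_c^{00}(x), gamma_c^{11}(x) and ITT_c under monotonicity only.
    Arguments may be scalars or NumPy arrays (evaluated elementwise)."""
    lo00 = np.maximum(0.0, (pi_c + pi_n) / pi_c * p00 - pi_n / pi_c)
    hi00 = np.minimum(1.0, (pi_c + pi_n) / pi_c * p00)
    lo11 = np.maximum(0.0, (pi_a + pi_c) / pi_c * p11 - pi_a / pi_c)
    hi11 = np.minimum(1.0, (pi_a + pi_c) / pi_c * p11)
    # ITT_c is smallest at (lo11, hi00) and largest at (hi11, lo00)
    return (lo11 - hi00, hi11 - lo00), (lo00, hi00), (lo11, hi11)

# Illustrative inputs at a single x (made-up numbers, not the vignette's DGP).
# The exclusion restriction fails: gamma_a^{11} = 0.08 differs from gamma_a^{10} = 0.10.
pi_a, pi_c, pi_n = 0.2, 0.5, 0.3
g_c11, g_a11, g_a10 = 0.15, 0.08, 0.10
g_c00, g_n00, g_n01 = 0.30, 0.40, 0.40
p11 = (pi_c * g_c11 + pi_a * g_a11) / (pi_c + pi_a)
p00 = (pi_c * g_c00 + pi_n * g_n00) / (pi_c + pi_n)
p10, p01 = g_a10, g_n01

(itt_lo, itt_hi), _, _ = itt_c_bounds(pi_a, pi_c, pi_n, p00, p11)
er11, er00 = complier_gammas_under_er(pi_a, pi_c, pi_n, p00, p11, p01, p10)
print("true ITT_c:", g_c11 - g_c00)
print("ITT_c bounds:", itt_lo, itt_hi)
print("ER-implied ITT_c:", er11 - er00)
```

With these inputs the bounds are wide, $[-0.54, 0.182]$, and contain the true $ITT_c$ of $-0.15$, while the exclusion-restriction value ($-0.158$) is slightly off because the restriction does not hold. In the vignette, the same formulas are applied to each posterior draw of the estimated $\pi_s(x)$ and $p_{1\mid vz}(x)$.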

We can do all of this with calls to stochtree from Python (or R). But first, let's generate some test data.

## Simulate the data

Start with some initial setup / housekeeping

In [None]:
# Size of the training sample
n = 20000

# To set the seed for reproducibility/illustration purposes, replace "None" with a positive integer
random_seed = None
if random_seed is not None:
    rng = np.random.default_rng(random_seed)
else:
    rng = np.random.default_rng()

First, we generate the instrument exogenously

In [None]:
z = rng.binomial(n=1, p=0.5, size=n)

Next, we generate the covariate. (For this example, let's think of it as patient age, although we are generating it from a uniform distribution between 0 and 3, so you have to imagine that it has been pre-standardized to this scale. It keeps the DGPs cleaner for illustration purposes.)

In [None]:
p_X = 1
X = rng.uniform(low=0., high=3., size=(n,p_X))
x = X[:,0] # for ease of reference later

Next, we generate the principal strata $S$ based on the observed value of $X$. We generate them according to a multinomial logistic model with two coefficients per stratum, an intercept and a slope. Here, these coefficients are set so that the probability of being a never taker decreases with age.

In [None]:
alpha_a = 0
beta_a = 1

alpha_n = 1
beta_n = -1

alpha_c = 1
beta_c = 1

# Define function (a logistic model) to generate Pr(S = s | X = x)
def pi_s(xval, alpha_a, beta_a, alpha_n, beta_n, alpha_c, beta_c):
    w_a = np.exp(alpha_a + beta_a*xval)
    w_n = np.exp(alpha_n + beta_n*xval)
    w_c = np.exp(alpha_c + beta_c*xval)
    w = np.column_stack((w_a, w_n, w_c))
    w_rowsum = np.sum(w, axis=1, keepdims=True)
    return np.divide(w, w_rowsum)
    
# Sample principal strata based on observed probabilities
strata_probs = pi_s(X[:,0], alpha_a, beta_a, alpha_n, beta_n, alpha_c, beta_c)
s = np.empty_like(X[:,0], dtype=str)
for i in range(s.size):
    s[i] = rng.choice(a=['a','n','c'], size=1, p=strata_probs[i,:])[0]
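
The loop above calls `rng.choice` once per observation, which is slow for $n = 20000$. A vectorized alternative, sketched below under the assumption that each row of `strata_probs` sums to one, draws all strata at once by comparing one uniform draw per row against that row's cumulative probabilities (inverse-CDF sampling). The helper name is hypothetical, and it consumes random numbers differently from the loop, so a fixed seed will not reproduce the same draws.

```python
import numpy as np

def sample_strata_vectorized(strata_probs, labels, rng):
    """Draw one label per row of an (n, k) probability matrix by inverse-CDF sampling."""
    cum = np.cumsum(strata_probs, axis=1)
    cum[:, -1] = 1.0                       # guard against rounding in the row sums
    u = rng.uniform(size=(strata_probs.shape[0], 1))
    idx = (u > cum).sum(axis=1)            # number of cumulative thresholds below u
    return np.asarray(labels)[idx]

rng = np.random.default_rng(0)
probs = np.tile([0.2, 0.3, 0.5], (100_000, 1))
draws = sample_strata_vectorized(probs, ["a", "n", "c"], rng)
print({lab: np.mean(draws == lab) for lab in ["a", "n", "c"]})
```

In this notebook it would replace the loop with `s = sample_strata_vectorized(strata_probs, ['a', 'n', 'c'], rng)`.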


Next, we generate the treatment variable, here denoted $V$ (for "vaccine"), as a *deterministic* function of $S$ and $Z$; this is what gives the principal strata their meaning.

In [None]:
v = 1*(s=='a') + 0*(s=='n') + z*(s=="c") + (1-z)*(s == "d")
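
The expression above encodes the principal strata table: always takers are treated regardless of $Z$, never takers never are, compliers follow $Z$, and defiers do the opposite (the defier term is inert here because no defiers are simulated). A small sketch, with a hypothetical helper name, recovers $(V(0), V(1))$ for each stratum from this rule.

```python
def treatment_from_stratum(s, z):
    """V as a deterministic function of principal stratum s and instrument z."""
    return {"a": 1, "n": 0, "c": z, "d": 1 - z}[s]

# (V(0), V(1)) for each stratum reproduces the principal strata table
potential_treatments = {s: (treatment_from_stratum(s, 0), treatment_from_stratum(s, 1))
                        for s in ["n", "a", "c", "d"]}
print(potential_treatments)  # {'n': (0, 0), 'a': (1, 1), 'c': (0, 1), 'd': (1, 0)}
```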

Finally, the outcome structural model is specified, based on which the outcome is sampled. By varying this function in particular ways, we can alter the identification conditions.

In [None]:
def gamfun(xval, vval, zval, sval):
    """
    If this function depends on zval, then the exclusion restriction is violated.
    If this function does not depend on sval, then IV analysis wasn't necessary.
    If this function does not depend on x, then there are no HTEs.
    """
    baseline = norm.cdf(2 - 1*xval - 2.5*((xval-1.5)**2) - 0.5*zval + 1*(sval=="n") - 1*(sval=="a"))
    return baseline - 0.5*vval*baseline

y = rng.binomial(n=1, p=gamfun(X[:,0],v,z,s), size=n)
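
Two features of `gamfun` are worth checking directly. First, the vaccine multiplies baseline risk by 0.5, so the complier risk ratio is constant in $x$ while the risk difference is not. Second, because `gamfun` depends on `zval` through the `-0.5*zval` term, the exclusion restriction is violated in this simulation, as the docstring warns. The block below repeats the function so it runs on its own; the grid `xg` is an arbitrary illustrative choice.

```python
import numpy as np
from scipy.stats import norm

def gamfun(xval, vval, zval, sval):
    # Same outcome model as in the cell above, repeated so this block runs on its own
    baseline = norm.cdf(2 - 1*xval - 2.5*((xval-1.5)**2) - 0.5*zval + 1*(sval=="n") - 1*(sval=="a"))
    return baseline - 0.5*vval*baseline

xg = np.linspace(0.1, 2.9, 5)
# Vaccine effect for compliers at z = 1: constant risk ratio, x-dependent risk difference
ratio = gamfun(xg, 1, 1, "c") / gamfun(xg, 0, 1, "c")
diff = gamfun(xg, 1, 1, "c") - gamfun(xg, 0, 1, "c")
# Effect of z with v held fixed; nonzero values mean the exclusion restriction fails
er_gap = gamfun(xg, 0, 1, "c") - gamfun(xg, 0, 0, "c")
print(ratio)
print(diff)
print(er_gap)
```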

Lastly, we perform some organization for our supervised learning algorithms later on.

In [None]:
# Concatenate X, v and z for our supervised learning algorithms
Xall = np.concatenate((X, np.column_stack((v,z))), axis=1)

# Update the size of "X" to be the size of Xall
p_X = p_X + 2

# For the monotone probit model it is necessary to sort the observations so that the Z=1 cases are all together
# at the start of the outcome vector.  
sort_index = np.argsort(z)[::-1]
X = X[sort_index,:]
Xall = Xall[sort_index,:]
z = z[sort_index]
v = v[sort_index]
s = s[sort_index]
y = y[sort_index]
x = x[sort_index]

Now let's see if we can recover these functions from the observed data.

## Fit the outcome model

We have to fit three models here: two treatment models, $\operatorname{Pr}(V = 1 | Z = 1, X=x)$ and $\operatorname{Pr}(V = 1 | Z = 0,X = x)$, subject to the monotonicity constraint $\operatorname{Pr}(V = 1 | Z = 1, X=x) \geq \operatorname{Pr}(V = 1 | Z = 0,X = x)$, and an outcome model $\operatorname{Pr}(Y = 1 | V = v, Z = z, X = x)$. All of this will be done with stochtree.

The outcome model is fit with a single (S-learner) BART model that includes $V$ and $Z$ as covariates alongside $X$. This part of the model could also be fit as a T-learner or as a BCF model; here we use an S-learner for simplicity. Both the outcome and treatment models are probit models, and both use the well-known Albert and Chib (1993) data augmentation Gibbs sampler. This section covers the more straightforward outcome model. The next section describes how the monotonicity constraint is handled with a data augmentation Gibbs sampler.
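
The Albert and Chib (1993) step can be isolated as a small function. This is a minimal sketch mirroring the latent-variable update in the outcome-model loop below: given the current forest predictions $\eta_i$, each latent utility is drawn from $N(\eta_i, 1)$ truncated to be positive when $y_i = 1$ and non-positive when $y_i = 0$, using inverse-CDF sampling. The function name `sample_probit_latents` is hypothetical, not part of `stochtree`.

```python
import numpy as np
from scipy.stats import norm

def sample_probit_latents(eta, y, rng):
    """One Albert-Chib step: zed_i ~ N(eta_i, 1) truncated to (0, inf) if y_i = 1
    and to (-inf, 0] if y_i = 0, drawn by inverse-CDF sampling."""
    eta = np.asarray(eta, dtype=float)
    cut = norm.cdf(-eta)                       # Pr(zed_i <= 0) under N(eta_i, 1)
    u = rng.uniform(size=eta.shape)
    # Rescale u to (cut, 1) when y = 1 and to (0, cut) when y = 0
    u = np.where(y == 1, cut + u * (1.0 - cut), u * cut)
    return eta + norm.ppf(u)

rng = np.random.default_rng(42)
eta = rng.normal(size=1000)
y = rng.binomial(1, norm.cdf(eta))
zed = sample_probit_latents(eta, y, rng)
print(np.all((zed > 0) == (y == 1)))
```

The sign of each latent utility always agrees with the observed binary outcome, which is what makes the forest update on `zed` a valid probit update.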

These models could (and probably should) be wrapped as functions. Here they are implemented as scripts, with the full loops shown. The outputs at the end of the loops are stochtree forest objects from which we can extract posterior samples and generate predictions. In particular, the $ITT_c$ will be constructed using posterior counterfactual predictions derived from these forest objects.

We begin by setting a bunch of hyperparameters and instantiating the forest objects to be operated upon in the main sampling loop. We also initialize the latent variables.

In [None]:
# Fit the S-learner BART model for Pr(Y = 1 | V = v, Z = z, X = x)

# Set number of iterations
num_warmstart = 10
num_mcmc = 1000
num_samples = num_warmstart + num_mcmc

# Set a bunch of hyperparameters. These are ballpark default values.
alpha = 0.95
beta = 2
min_samples_leaf = 1
max_depth = 20
num_trees = 50
cutpoint_grid_size = 100
global_variance_init = 1.
tau_init = 0.5
leaf_prior_scale = np.array([[tau_init]])
leaf_regression = False
feature_types = np.append(np.repeat(0, p_X - 2), [1,1]).astype(int)
var_weights = np.repeat(1.0/p_X, p_X)
outcome_model_type = 0

# C++ dataset
forest_dataset = Dataset()
forest_dataset.add_covariates(Xall)

# Random number generator (std::mt19937)
if random_seed is not None:
    cpp_rng = RNG(random_seed)
else:
    cpp_rng = RNG()

# Sampling data structures
forest_model_config = ForestModelConfig(
    feature_types = feature_types, 
    num_trees = num_trees, 
    num_features = p_X, 
    num_observations = n, 
    variable_weights = var_weights, 
    leaf_dimension = 1, 
    alpha = alpha, 
    beta = beta, 
    min_samples_leaf = min_samples_leaf, 
    max_depth = max_depth, 
    leaf_model_type = outcome_model_type, 
    leaf_model_scale = leaf_prior_scale, 
    cutpoint_grid_size = cutpoint_grid_size
)
global_model_config = GlobalModelConfig(global_error_variance=1.0)
forest_sampler = ForestSampler(
    forest_dataset, global_model_config, forest_model_config
)

# Container of forest samples
forest_samples = ForestContainer(num_trees, 1, True, False)

# "Active" forest state
active_forest = Forest(num_trees, 1, True, False)

# Initialize the latent outcome zed
n1 = np.sum(y)
zed = 0.25*(2.0*y - 1.0)

# C++ outcome variable
outcome = Residual(zed)

# Initialize the active forest and subtract each root tree's predictions from outcome
forest_init_val = np.array([0.0])
forest_sampler.prepare_for_sampler(
    forest_dataset,
    outcome,
    active_forest,
    outcome_model_type,
    forest_init_val,
)

Now we enter the main loop, which involves only two steps: sample the forest given the latent utilities, then sample the latent utilities given the conditional means defined by the current forest.

In [None]:
gfr_flag = True
for i in range(num_samples):
    # The first num_warmstart iterations use the grow-from-root algorithm of He and Hahn
    if i >= num_warmstart:
        gfr_flag = False
    
    # Sample forest
    forest_sampler.sample_one_iteration(
        forest_samples, active_forest, forest_dataset, outcome, cpp_rng, 
        global_model_config, forest_model_config, keep_forest=True, gfr = gfr_flag
    )

    # Get the current means
    eta = np.squeeze(forest_samples.predict_raw_single_forest(forest_dataset, i))

    # Sample latent normals, truncated according to the observed outcome y
    mu0 = eta[y == 0]
    mu1 = eta[y == 1]
    u0 = rng.uniform(
        low=0.0,
        high=norm.cdf(0 - mu0),
        size=n-n1,
    )
    u1 = rng.uniform(
        low=norm.cdf(0 - mu1),
        high=1.0,
        size=n1,
    )
    zed[y == 0] = mu0 + norm.ppf(u0)
    zed[y == 1] = mu1 + norm.ppf(u1)

    # Update outcome
    new_outcome = np.squeeze(zed) - eta
    outcome.update_data(new_outcome)

## Fit the monotone probit model(s)

The monotonicity constraint relies on a data augmentation as described in Papakostas et al. (2023). The implementation of this sampler is inherently cumbersome, as one of the "data" vectors is constructed from some observed data and some latent data, and there are two forest objects, one of which applies to all of the observations and one of which applies only to those observations with $Z = 0$. We go into more detail about this sampler in a dedicated vignette. Here we include the code, but without reproducing the equations derived in Papakostas et al. (2023). What is most important is simply that

\begin{equation}
\begin{aligned}
\operatorname{Pr}(V=1 \mid Z=0, X=x)&=\pi_a(x) = \Phi_f(x)\Phi_h(x),\\
\operatorname{Pr}(V=1 \mid Z=1, X=x)&=\pi_a(x)+\pi_c(x) = \Phi_f(x),
\end{aligned}
\end{equation}
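where $\Phi_{\mu}(x) = \Phi(\mu(x))$ denotes the standard normal cumulative distribution function evaluated at $\mu(x)$ (equivalently, the probability that a normal random variable with mean $\mu(x)$ and variance 1 is positive), and $f$ and $h$ are the functions represented by the two BART forests.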
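
This parameterization enforces monotonicity by construction: because $\Phi_h(x) \le 1$, we have $\Phi_f(x)\Phi_h(x) \le \Phi_f(x)$ for any functions $f$ and $h$, so the forests are free to take any values. The sketch below (hypothetical function name, arbitrary inputs) maps forest outputs to the three stratum probabilities, the same transformation later applied to the posterior predictions, and confirms they are valid probabilities.

```python
import numpy as np
from scipy.stats import norm

def strata_from_forest_outputs(f_vals, h_vals):
    """Map f(x), h(x) to (pi_a, pi_c, pi_n) under
    Pr(V=1|Z=1) = Phi(f) and Pr(V=1|Z=0) = Phi(f) * Phi(h)."""
    p_z1 = norm.cdf(f_vals)           # Pr(V = 1 | Z = 1, X = x) = pi_a + pi_c
    p_z0 = p_z1 * norm.cdf(h_vals)    # Pr(V = 1 | Z = 0, X = x) = pi_a
    return p_z0, p_z1 - p_z0, 1.0 - p_z1

rng = np.random.default_rng(0)
f_vals, h_vals = 3 * rng.normal(size=1000), 3 * rng.normal(size=1000)  # arbitrary values
pi_a, pi_c, pi_n = strata_from_forest_outputs(f_vals, h_vals)
print(pi_c.min() >= 0, np.allclose(pi_a + pi_c + pi_n, 1.0))
```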

We first create a secondary data matrix for the $Z=0$ group only. We also set all of the hyperparameters and initialize the latent variables.

In [None]:
# Fit the monotone probit model to the treatment such that Pr(V = 1 | Z = 1, X=x) >= Pr(V = 1 | Z = 0,X = x) 
X_h = X[z==0,:]
n0 = np.sum(z==0)
n1 = np.sum(z==1)

num_trees_f = 50
num_trees_h = 20
feature_types = np.repeat(0, p_X-2).astype(int)
var_weights = np.repeat(1.0/(p_X - 2.0), p_X - 2)
cutpoint_grid_size = 100
global_variance_init = 1.
tau_init_f = 1/num_trees_f
tau_init_h = 1/num_trees_h
leaf_prior_scale_f = np.array([[tau_init_f]])
leaf_prior_scale_h = np.array([[tau_init_h]])
leaf_regression = False # fit a constant leaf mean BART model

# Instantiate the C++ dataset objects
forest_dataset_f = Dataset()
forest_dataset_f.add_covariates(X)
forest_dataset_h = Dataset()
forest_dataset_h.add_covariates(X_h)

# Tell it we're fitting a normal BART model
outcome_model_type = 0

# Set up model configuration objects
forest_model_config_f = ForestModelConfig(
    feature_types = feature_types, 
    num_trees = num_trees_f, 
    num_features = X.shape[1], 
    num_observations = n, 
    variable_weights = var_weights, 
    leaf_dimension = 1, 
    alpha = alpha, 
    beta = beta, 
    min_samples_leaf = min_samples_leaf, 
    max_depth = max_depth, 
    leaf_model_type = outcome_model_type, 
    leaf_model_scale = leaf_prior_scale_f, 
    cutpoint_grid_size = cutpoint_grid_size
)
forest_model_config_h = ForestModelConfig(
    feature_types = feature_types, 
    num_trees = num_trees_h, 
    num_features = X_h.shape[1], 
    num_observations = n0, 
    variable_weights = var_weights, 
    leaf_dimension = 1, 
    alpha = alpha, 
    beta = beta, 
    min_samples_leaf = min_samples_leaf, 
    max_depth = max_depth, 
    leaf_model_type = outcome_model_type, 
    leaf_model_scale = leaf_prior_scale_h, 
    cutpoint_grid_size = cutpoint_grid_size
)
global_model_config = GlobalModelConfig(global_error_variance=global_variance_init)

# Instantiate the sampling data structures
forest_sampler_f = ForestSampler(
    forest_dataset_f, global_model_config, forest_model_config_f
)
forest_sampler_h = ForestSampler(
    forest_dataset_h, global_model_config, forest_model_config_h
)

# Instantiate containers of forest samples
forest_samples_f = ForestContainer(num_trees_f, 1, True, False)
forest_samples_h = ForestContainer(num_trees_h, 1, True, False)

# Instantiate "active" forests
active_forest_f = Forest(num_trees_f, 1, True, False)
active_forest_h = Forest(num_trees_h, 1, True, False)

# Set algorithm specifications
# These are set in the earlier script for the outcome model; the number of draws must match
# because the ITT_c is computed draw by draw from both sets of forests

# num_warmstart = 40
# num_mcmc = 5000
# num_samples = num_warmstart + num_mcmc

# Initialize the Markov chain

# Initialize (R0, R1), the latent binary variables that enforce monotonicity
v1 = v[z==1]
v0 = v[z==0]

R1 = np.empty(n0, dtype=float)
R0 = np.empty(n0, dtype=float)

R1[v0==1] = 1
R0[v0==1] = 1

nv0 = np.sum(v0==0)
R1[v0 == 0] = 0
R0[v0 == 0] = rng.choice([0,1], size = nv0)

# The first n1 observations of vaug are actually observed
# The next n0 of them are the latent variable R1
vaug = np.append(v1, R1)

# Initialize the Albert and Chib latent Gaussian variables
z_f = (2.0*vaug - 1.0)
z_h = (2.0*R0 - 1.0)
z_f = z_f/np.std(z_f)
z_h = z_h/np.std(z_h)

# Pass these variables to the BART models as outcome variables
outcome_f = Residual(z_f)
outcome_h = Residual(z_h)

# Initialize active forests to constant (0) predictions
forest_init_val_f = np.array([0.0])
forest_sampler_f.prepare_for_sampler(
    forest_dataset_f,
    outcome_f,
    active_forest_f,
    outcome_model_type,
    forest_init_val_f,
)
forest_init_val_h = np.array([0.0])
forest_sampler_h.prepare_for_sampler(
    forest_dataset_h,
    outcome_h,
    active_forest_h,
    outcome_model_type,
    forest_init_val_h,
)

Now we run the main sampling loop, which consists of three key steps: sample the BART forests given the latent probit utilities; sample the latent binary pairs $(R_0, R_1)$ given the forest predictions (this is the step that enforces monotonicity); and finally sample the latent utilities.

In [None]:
# Run the Markov chain

# Initialize the Markov chain with num_warmstart grow-from-root iterations
gfr_flag = True
for i in range(num_samples):
    # Switch over to random walk Metropolis-Hastings tree updates after num_warmstart
    if i >= num_warmstart:
        gfr_flag = False
    
    # Step 1: Sample the BART forests

    # Sample forest for the function f based on vaug = (v1, R1)
    forest_sampler_f.sample_one_iteration(
        forest_samples_f, active_forest_f, forest_dataset_f, outcome_f, cpp_rng, 
        global_model_config, forest_model_config_f, keep_forest=True, gfr = gfr_flag
    )

    # Sample forest for the function h based on outcome R0
    forest_sampler_h.sample_one_iteration(
        forest_samples_h, active_forest_h, forest_dataset_h, outcome_h, cpp_rng, 
        global_model_config, forest_model_config_h, keep_forest=True, gfr = gfr_flag
    )

    # Get the current means
    eta_f = np.squeeze(forest_samples_f.predict_raw_single_forest(forest_dataset_f, i))
    eta_h = np.squeeze(forest_samples_h.predict_raw_single_forest(forest_dataset_h, i))

    # Step 2: sample the latent binary pair (R0, R1) given eta_h and eta_f,
    # only for Z = 0 observations with V = 0 (V = 1 forces R0 = R1 = 1)
    idx_v0 = np.where(v0 == 0)[0]
    pf = norm.cdf(eta_f[n1 + idx_v0])  # Z = 0 rows of eta_f start at position n1
    ph = norm.cdf(eta_h[idx_v0])

    # Three cases: (R0, R1) = (0,0), (0,1), (1,0)
    w1 = (1 - ph)*(1 - pf)
    w2 = (1 - ph)*pf
    w3 = ph*(1 - pf)

    # Normalize; w_sum is used so that the strata labels in s are not overwritten
    w_sum = w1 + w2 + w3
    w1 = w1/w_sum
    w2 = w2/w_sum
    w3 = w3/w_sum

    u = rng.uniform(low=0, high=1, size=idx_v0.size)
    temp = 1*(u < w1) + 2*((u >= w1) & (u < (w1 + w2))) + 3*(u >= (w1 + w2))

    R1[v0==0] = 1*(temp==2)
    R0[v0==0] = 1*(temp==3)

    # Rebuild vaug with the updated R1 component
    vaug = np.append(v1, R1)

    # Step 3: sample the latent normals, given (R0, vaug) and the forest predictions

    # First z_h, the latent utilities for R0
    mu1 = eta_h[R0==1]
    U1 = rng.uniform(
        low=norm.cdf(0 - mu1), 
        high=1,
        size=np.sum(R0).astype(int)
    )
    z_h[R0==1] = mu1 + norm.ppf(U1)

    mu0 = eta_h[R0==0]
    U0 = rng.uniform(
        low=0, 
        high=norm.cdf(0 - mu0),
        size=(n0 - np.sum(R0)).astype(int)
    )
    z_h[R0==0] = mu0 + norm.ppf(U0)

    # Then z_f, the latent utilities for vaug
    mu1 = eta_f[vaug==1]
    U1 = rng.uniform(
        low=norm.cdf(0 - mu1), 
        high=1,
        size=np.sum(vaug).astype(int)
    )
    z_f[vaug==1] = mu1 + norm.ppf(U1)

    mu0 = eta_f[vaug==0]
    U0 = rng.uniform(
        low=0, 
        high=norm.cdf(0 - mu0),
        size=(n - np.sum(vaug)).astype(int)
    )
    z_f[vaug==0] = mu0 + norm.ppf(U0)

    # Propagate the updated outcomes through the BART models
    new_outcome_h = np.squeeze(z_h) - eta_h
    outcome_h.update_data(new_outcome_h)

    new_outcome_f = np.squeeze(z_f) - eta_f
    outcome_f.update_data(new_outcome_f)
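
Step 2 of the loop above is the least standard part of the sampler. As we read the code (an interpretation, not a restatement of the Papakostas et al. derivation), a $Z=0$ unit has $V = R_0 R_1$ with $R_1 \sim \text{Bernoulli}(\Phi_f(x))$ and $R_0 \sim \text{Bernoulli}(\Phi_h(x))$, which reproduces $\operatorname{Pr}(V=1 \mid Z=0, X=x) = \Phi_f(x)\Phi_h(x)$. Observing $V=0$ rules out $(R_0, R_1) = (1,1)$, leaving the three weighted cases. The standalone sketch below (hypothetical function name, constant illustrative predictions) draws from that conditional and checks the frequencies against the normalized weights.

```python
import numpy as np
from scipy.stats import norm

def sample_r_pair_given_v_zero(eta_f, eta_h, rng):
    """For Z = 0 units with V = 0, draw (R0, R1) given V = R0 * R1 = 0.
    Allowed pairs are (R0, R1) = (0, 0), (0, 1), (1, 0)."""
    pf, ph = norm.cdf(eta_f), norm.cdf(eta_h)
    w = np.column_stack(((1 - ph) * (1 - pf),   # (R0, R1) = (0, 0)
                         (1 - ph) * pf,         # (R0, R1) = (0, 1)
                         ph * (1 - pf)))        # (R0, R1) = (1, 0)
    cum = np.cumsum(w / w.sum(axis=1, keepdims=True), axis=1)
    cum[:, -1] = 1.0                            # guard against rounding
    u = rng.uniform(size=(w.shape[0], 1))
    case = (u > cum).sum(axis=1)                # 0, 1 or 2
    return (case == 2).astype(float), (case == 1).astype(float)   # R0, R1

rng = np.random.default_rng(0)
n_units = 200_000
eta_f, eta_h = np.full(n_units, 0.3), np.full(n_units, -0.2)
r0, r1 = sample_r_pair_given_v_zero(eta_f, eta_h, rng)
pf, ph = norm.cdf(0.3), norm.cdf(-0.2)
denom = 1 - pf * ph                             # total weight of the three cases
print(r1.mean(), (1 - ph) * pf / denom)
print(r0.mean(), ph * (1 - pf) / denom)
```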

## Extracting the estimates and plotting the results.

Now for the most interesting part, which is taking the stochtree BART model fits and producing the causal estimates of interest. 

First we set up our grid for plotting the functions in $X$. This is possible in this example because the moderator, age, is one-dimensional; in many applied problems this will not be the case and visualization will be substantially trickier.

In [None]:
# Extract the credible intervals for the conditional treatment effects as a function of x.
# We use a grid of values for plotting, with grid points that are typically fewer than the number of observations.

ngrid = 200
xgrid = np.linspace(start=0.1, stop=2.9, num=ngrid)
X_11 = np.column_stack((xgrid, np.ones(ngrid), np.ones(ngrid)))
X_00 = np.column_stack((xgrid, np.zeros(ngrid), np.zeros(ngrid)))
X_01 = np.column_stack((xgrid, np.zeros(ngrid), np.ones(ngrid)))
X_10 = np.column_stack((xgrid, np.ones(ngrid), np.zeros(ngrid)))

Next, we compute the truth function evaluations on this plotting grid, using the functions defined above when we generated our data.

In [None]:
# Compute the true conditional outcome probabilities for plotting
pi_strat = pi_s(xgrid, alpha_a, beta_a, alpha_n, beta_n, alpha_c, beta_c)
w_a = pi_strat[:,0]
w_n = pi_strat[:,1]
w_c = pi_strat[:,2]

w = (w_c/(w_a + w_c))
p11_true = w*gamfun(xgrid,1,1,"c") + (1-w)*gamfun(xgrid,1,1,"a")

w = (w_c/(w_n + w_c))
p00_true = w*gamfun(xgrid,0,0,"c") + (1-w)*gamfun(xgrid,0,0,"n")

# Compute the true ITT_c for plotting and comparison
itt_c_true = gamfun(xgrid,1,1,"c") - gamfun(xgrid,0,0,"c")

# Compute the true LATE for plotting and comparison
LATE_true0 = gamfun(xgrid,1,0,"c") - gamfun(xgrid,0,0,"c")
LATE_true1 = gamfun(xgrid,1,1,"c") - gamfun(xgrid,0,1,"c")

Next, we populate the data structures for stochtree to operate on, call the predict functions to extract the predictions, and convert them to the probability scale using `norm.cdf` from `scipy.stats`.

In [None]:
# Datasets for counterfactual predictions
forest_dataset_grid = Dataset()
forest_dataset_grid.add_covariates(np.expand_dims(xgrid, 1))
forest_dataset_11 = Dataset()
forest_dataset_11.add_covariates(X_11)
forest_dataset_00 = Dataset()
forest_dataset_00.add_covariates(X_00)
forest_dataset_10 = Dataset()
forest_dataset_10.add_covariates(X_10)
forest_dataset_01 = Dataset()
forest_dataset_01.add_covariates(X_01)

# Forest predictions
preds_00 = forest_samples.predict(forest_dataset_00)
preds_11 = forest_samples.predict(forest_dataset_11)
preds_01 = forest_samples.predict(forest_dataset_01)
preds_10 = forest_samples.predict(forest_dataset_10)

# Probability transformations
phat_00 = norm.cdf(preds_00)
phat_11 = norm.cdf(preds_11)
phat_01 = norm.cdf(preds_01)
phat_10 = norm.cdf(preds_10)

preds_ac = forest_samples_f.predict(forest_dataset_grid)
phat_ac = norm.cdf(preds_ac)

preds_adj = forest_samples_h.predict(forest_dataset_grid)
phat_a = norm.cdf(preds_ac) * norm.cdf(preds_adj)
phat_c = phat_ac - phat_a
phat_n = 1 - phat_ac

Now we may plot posterior means of various quantities (as a function of $X$) to visualize how well the models are fitting.

In [None]:
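fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.scatter(p11_true, np.mean(phat_11, axis=1), color="black")
ax1.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
ax1.set_xlabel(r"True $p_{1|11}(x)$")
ax1.set_ylabel("Posterior mean")
ax2.scatter(p00_true, np.mean(phat_00, axis=1), color="black")
ax2.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
ax2.set_xlabel(r"True $p_{1|00}(x)$")
ax2.set_ylabel("Posterior mean")
plt.tight_layout()
plt.show()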

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex="none", sharey="none")
ax1.scatter(np.mean(phat_ac, axis=1), w_c + w_a, color="black")
ax1.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
ax1.set_xlim(0.5,1.1)
ax1.set_ylim(0.5,1.1)
ax1.set_xlabel(r"Posterior mean of $\pi_a(x)+\pi_c(x)$")
ax1.set_ylabel("True value")
ax2.scatter(np.mean(phat_a, axis=1), w_a, color="black")
ax2.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
ax2.set_xlim(0.1,0.4)
ax2.set_ylim(0.1,0.3)
ax2.set_xlabel(r"Posterior mean of $\pi_a(x)$")
ax3.scatter(np.mean(phat_c, axis=1), w_c, color="black")
ax3.axline((0, 0), slope=1, color="red", linestyle=(0, (3, 3)))
ax3.set_xlim(0.4,0.9)
ax3.set_ylim(0.4,0.8)
ax3.set_xlabel(r"Posterior mean of $\pi_c(x)$")
plt.tight_layout()
plt.show()

These plots are not as clean as we might hope, but this mostly reflects how difficult it is to learn conditional probabilities from binary outcomes. Capturing the broad trend, as we do here, turns out to be adequate for estimating treatment effects. Fit improves with simpler DGPs and larger training sets, as experimenting with this script will confirm.

Lastly, we can construct the estimate of the $ITT_c$ and compare it to the true value as well as to the $Z=0$ and $Z=1$ complier average treatment effects (also called "local average treatment effects" or LATE). The key step is, at each iteration of the sampler, to center the posterior over the identified interval at the value implied by a valid exclusion restriction. For some draws this is not possible because that value falls outside the identified region; in those cases the code below instead piles posterior mass near the violated boundary.
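The exclusion-restriction value and the identified interval computed in the loop come from a mixture argument. The sketch below assumes monotonicity (only always-takers $a$, compliers $c$, and never-takers $n$) and uses notation introduced just for this derivation: $\pi_s$ is the probability of stratum $s$ (`phat_a`, `phat_c`, `phat_n`), $p_{wz} = \Pr(Y=1 \mid W=w, Z=z, X)$ is the observed outcome probability for treatment $w$ and instrument $z$ (`phat_11`, `phat_10`, and so on), and $\gamma^s_{wz}$ is the outcome probability for stratum $s$. Units with $W=1, Z=1$ are always-takers or compliers, while units with $W=1, Z=0$ are always-takers only, so

$$
p_{11} = \frac{\pi_a \gamma^a_{11} + \pi_c \gamma^c_{11}}{\pi_a + \pi_c}, \qquad p_{10} = \gamma^a_{10}
\quad\Longrightarrow\quad
\gamma^c_{11} = \frac{\pi_a + \pi_c}{\pi_c}\, p_{11} - \frac{\pi_a}{\pi_c}\, \gamma^a_{11}.
$$

An exclusion restriction sets $\gamma^a_{11} = \gamma^a_{10} = p_{10}$, which gives `gamest11`. Without it, $\gamma^a_{11}$ is only known to lie in $[0, 1]$, and since $\gamma^c_{11}$ is decreasing in $\gamma^a_{11}$ this gives the identified interval (`lower11`, `upper11`):

$$
\max\left(0,\ \frac{\pi_a + \pi_c}{\pi_c}\, p_{11} - \frac{\pi_a}{\pi_c}\right) \le \gamma^c_{11} \le \min\left(1,\ \frac{\pi_a + \pi_c}{\pi_c}\, p_{11}\right).
$$

The same argument with never-takers in place of always-takers ($p_{00}$ mixes $n$ and $c$, while $p_{01} = \gamma^n_{01}$) gives `gamest00`, `lower00`, and `upper00`.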

In [None]:
# Generate draws from the posterior of the treatment effect
# centered at the point-identified value under the exclusion restriction
itt_c = np.empty((ngrid, phat_c.shape[1]))
late = np.empty((ngrid, phat_c.shape[1]))
ss = 6  # beta concentration (a + b) for the in-range draws
for j in range(phat_c.shape[1]):
    # Value of gamma11 implied by an exclusion restriction
    gamest11 = ((phat_a[:,j] + phat_c[:,j])/phat_c[:,j])*phat_11[:,j] - phat_10[:,j]*phat_a[:,j]/phat_c[:,j]

    # Identified region for gamma11
    lower11 = np.maximum(0., ((phat_a[:,j] + phat_c[:,j])/phat_c[:,j])*phat_11[:,j] - phat_a[:,j]/phat_c[:,j])
    upper11 = np.minimum(1., ((phat_a[:,j] + phat_c[:,j])/phat_c[:,j])*phat_11[:,j])

    # Center a beta distribution at gamma11, restricted to (lower11, upper11):
    # rescale the target mean onto (0, 1), draw from a beta on (0, 1), then
    # shift and scale the draw back onto the restricted interval
    m11 = (gamest11 - lower11)/(upper11 - lower11)

    # Parameters of the beta
    a1 = ss*m11
    b1 = ss*(1-m11)

    # When the rescaled mean is out of range, sample from a beta with mass piled near the violated boundary
    a1[m11<0] = 1
    b1[m11<0] = 5
    
    a1[m11>1] = 5
    b1[m11>1] = 1

    # Value of gamma00 implied by an exclusion restriction
    gamest00 = ((phat_n[:,j] + phat_c[:,j])/phat_c[:,j])*phat_00[:,j] - phat_01[:,j]*phat_n[:,j]/phat_c[:,j]

    # Identified region for gamma00
    lower00 = np.maximum(0., ((phat_n[:,j] + phat_c[:,j])/phat_c[:,j])*phat_00[:,j] - phat_n[:,j]/phat_c[:,j])
    upper00 = np.minimum(1., ((phat_n[:,j] + phat_c[:,j])/phat_c[:,j])*phat_00[:,j])

    # Center a beta distribution at gamma00, restricted to (lower00, upper00):
    # rescale the target mean onto (0, 1), draw from a beta on (0, 1), then
    # shift and scale the draw back onto the restricted interval
    m00 = (gamest00 - lower00)/(upper00 - lower00)

    a0 = ss*m00
    b0 = ss*(1-m00)
    a0[m00<0] = 1
    b0[m00<0] = 5    
    a0[m00>1] = 5
    b0[m00>1] = 1

    # ITT_c draw (no exclusion restriction) and LATE (under the exclusion restriction)
    itt_c[:,j] = lower11 + (upper11 - lower11)*rng.beta(a=a1,b=b1,size=ngrid) - (lower00 + (upper00 - lower00)*rng.beta(a=a0,b=b0,size=ngrid))
    late[:,j] = gamest11 - gamest00

upperq = np.quantile(itt_c, q=0.975, axis=1)
lowerq = np.quantile(itt_c, q=0.025, axis=1)
upperq_er = np.quantile(late, q=0.975, axis=1)
lowerq_er = np.quantile(late, q=0.025, axis=1)
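To check the algebra behind `gamest11`, `lower11`, and `upper11`, the following standalone sketch builds a single hypothetical covariate value where the exclusion restriction holds by construction (all numbers are made up for illustration). It forms the observed mixture probabilities, then confirms that the loop's formulas recover the true complier outcome probability and that this value lies inside the identified interval.

```python
# Hypothetical strata probabilities at one covariate value (a = always-taker, c = complier)
pi_a, pi_c = 0.2, 0.5

# Hypothetical outcome probabilities; the exclusion restriction holds, so the
# always-taker outcome probability does not depend on the instrument
gamma_a = 0.3      # always-takers (W=1) under either Z
gamma_c11 = 0.1    # compliers with W=1, Z=1

# Observed outcome probabilities implied by the mixture
p11 = (pi_a * gamma_a + pi_c * gamma_c11) / (pi_a + pi_c)  # W=1, Z=1: a or c
p10 = gamma_a                                              # W=1, Z=0: a only

# Same formulas as in the sampler loop, for scalars
gamest11 = ((pi_a + pi_c) / pi_c) * p11 - p10 * pi_a / pi_c
lower11 = max(0.0, ((pi_a + pi_c) / pi_c) * p11 - pi_a / pi_c)
upper11 = min(1.0, ((pi_a + pi_c) / pi_c) * p11)

print(gamest11, lower11, upper11)
```

Up to floating-point rounding, `gamest11` equals `gamma_c11`; if `gamma_a` were changed for `Z=1` only (an exclusion violation), `gamest11` would drift away from the truth while the interval would still contain it.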

And now we can plot all of this, shading the pointwise 95% posterior intervals (2.5% and 97.5% quantiles) with [pyplot's `fill` function](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.fill.html).

In [None]:
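plt.ylim(-0.75, 0.05)

# 95% posterior bands: ITT_c without the exclusion restriction (olive)
# and the effect under an assumed exclusion restriction (blue)
plt.fill(np.append(xgrid, xgrid[::-1]), np.append(lowerq, upperq[::-1]), color = (0.5,0.5,0,0.25), label = "ITT_c 95% band")
plt.fill(np.append(xgrid, xgrid[::-1]), np.append(lowerq_er, upperq_er[::-1]), color = (0,0,0.5,0.25), label = "ER 95% band")

# Posterior means
itt_c_est = np.mean(itt_c, axis=1)
late_est = np.mean(late, axis=1)
plt.plot(xgrid, late_est, color = "darkgrey", label = "posterior mean under ER")
plt.plot(xgrid, itt_c_est, color = "gold", label = "posterior mean ITT_c")

# True effects
plt.plot(xgrid, LATE_true0, color = "black", linestyle = (0, (2, 2)), label = "true LATE, Z=0")
plt.plot(xgrid, LATE_true1, color = "black", linestyle = (0, (4, 4)), label = "true LATE, Z=1")
plt.plot(xgrid, itt_c_true, color = "black", label = "true ITT_c")

plt.xlabel("X")
plt.legend(fontsize = "small")
plt.show()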
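The shaded bands are drawn by tracing the lower quantile curve from left to right and then the upper quantile curve back from right to left, so that `plt.fill` receives a single closed polygon. Here is a small standalone sketch of that construction, with synthetic draws standing in for `itt_c` (hypothetical data, for illustration only):

```python
import numpy as np

# Hypothetical posterior draws: one row per grid point, one column per draw
rng_demo = np.random.default_rng(1)
grid = np.linspace(0, 1, 5)
draws = rng_demo.normal(loc=-0.3, scale=0.1, size=(grid.size, 200))

# Pointwise 95% credible band
lo = np.quantile(draws, q=0.025, axis=1)
hi = np.quantile(draws, q=0.975, axis=1)

# Polygon vertices: along the lower curve, then back along the upper curve
poly_x = np.append(grid, grid[::-1])
poly_y = np.append(lo, hi[::-1])
```

Passing `poly_x` and `poly_y` to `plt.fill` shades the region between the two curves; the polygon's first and last vertices share the same x-coordinate, so the outline closes cleanly at the left edge.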

With a valid exclusion restriction, the three black curves (solid: $ITT_c$; short dashes: $Z=0$ LATE; long dashes: $Z=1$ LATE) would coincide. Without an exclusion restriction, as here, the direct effect of $Z$ on $Y$ (the vaccine reminder on flu status) makes these three treatment effects differ. Specifically, the $ITT_c$ compares getting the vaccine *and* the reminder to getting neither. Since both have risk-reducing effects, we see a larger risk reduction than either LATE at all values of $X$. Meanwhile, the two LATEs compare the isolated effect of the vaccine among compliers who did and did not get the reminder, respectively. Here, not getting the reminder makes the vaccine more effective: in our DGP the vaccine's risk reduction is a fraction of baseline risk, and the reminder lowers baseline risk.
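A toy calculation makes this concrete. Suppose, purely hypothetically (these are not the notebook's DGP values), that a complier's baseline flu risk is 0.4, the reminder multiplies risk by 0.75, and the vaccine multiplies risk by 0.5. The three contrasts then differ even though each factor acts the same way on everyone:

```python
# Hypothetical multiplicative risk model for a complier (not the notebook's DGP)
BASELINE = 0.4
REMINDER_MULT = 0.75   # direct effect of Z=1, which violates the exclusion restriction
VACCINE_MULT = 0.5     # effect of W=1

def risk(w, z):
    """Flu probability for a complier with treatment w and instrument z."""
    return BASELINE * (VACCINE_MULT if w else 1.0) * (REMINDER_MULT if z else 1.0)

itt_c_toy = risk(1, 1) - risk(0, 0)   # vaccine and reminder vs. neither
late_z0 = risk(1, 0) - risk(0, 0)     # vaccine effect without the reminder
late_z1 = risk(1, 1) - risk(0, 1)     # vaccine effect with the reminder

print(round(itt_c_toy, 3), round(late_z0, 3), round(late_z1, 3))  # prints: -0.25 -0.2 -0.15
```

The $ITT_c$ is the largest risk reduction because it stacks the reminder's effect on top of the vaccine's, and the vaccine removes more absolute risk when the reminder has not already lowered the baseline.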

We also see that the posterior mean of the $ITT_c$ (gold) is very close to the posterior mean under an assumed exclusion restriction (gray). This is by design: the beta draws are centered at the exclusion-restriction value, so the two deviate only through Monte Carlo variation or in the rare draws where the exclusion restriction is incompatible with the identified interval.
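The "by design" claim follows from the mean of the shifted and scaled beta draw. A $\text{Beta}(s m, s(1-m))$ variable $B$ has mean $m$, so $\ell + (u - \ell)B$ has mean $\ell + (u - \ell)m$, which equals the exclusion-restriction value whenever $m = (\hat\gamma - \ell)/(u - \ell)$ lies in $[0, 1]$. The sketch below checks this by simulation for one hypothetical interval, reusing the loop's concentration `ss = 6`; the interval endpoints and target are made up.

```python
import numpy as np

rng_demo = np.random.default_rng(7)
ss = 6                       # beta concentration, as in the sampler loop
lower, upper = 0.05, 0.35    # hypothetical identified interval
gamest = 0.15                # hypothetical exclusion-restriction value inside it

# Rescale the target onto (0, 1) and set the beta parameters
m = (gamest - lower) / (upper - lower)
a, b = ss * m, ss * (1 - m)

# Draw on (0, 1), then map back onto (lower, upper)
draws = lower + (upper - lower) * rng_demo.beta(a=a, b=b, size=200_000)

print(draws.mean())  # close to gamest, up to Monte Carlo error
```

When the target falls outside the interval, the loop switches to $\text{Beta}(1, 5)$ or $\text{Beta}(5, 1)$, whose means $1/6$ and $5/6$ sit near the lower and upper boundary, respectively; only those draws pull the two posterior means apart.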

Changing the sample size and various aspects of the DGP in this script is a useful way to build intuition for how the DGP affects posterior inference, particularly how violations of assumptions affect accuracy and posterior uncertainty.

## References

Hahn, P Richard, Jared S Murray, and Ioanna Manolopoulou. 2016. “A Bayesian Partial Identification Approach to Inferring the Prevalence of Accounting Misconduct.” Journal of the American Statistical Association 111 (513): 14–26.

Hirano, Keisuke, Guido W. Imbens, Donald B. Rubin, and Xiao-Hua Zhou. 2000. “Assessing the Effect of an Influenza Vaccine in an Encouragement Design.” Biostatistics 1 (1): 69–88. https://doi.org/10.1093/biostatistics/1.1.69.

McDonald, Clement J, Siu L Hui, and William M Tierney. 1992. “Effects of Computer Reminders for Influenza Vaccination on Morbidity During Influenza Epidemics.” MD Computing: Computers in Medical Practice 9 (5): 304–12.

Richardson, Thomas S., Robin J. Evans, and James M. Robins. 2011. “Transparent Parametrizations of Models for Potential Outcomes.” In Bayesian Statistics 9. Oxford University Press. https://doi.org/10.1093/acprof:oso/9780199694587.003.0019.