# Doubly robust estimation with Chirho

## Setup

In [1]:
import collections
import math
import seaborn as sns
import matplotlib.pyplot as plt
from typing import Callable, Dict

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal
from pyro.infer import Predictive
from typing import Callable, Dict, List, Optional, Tuple, Union

from chirho.observational.handlers import condition
from chirho.interventional.handlers import do
from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.indexed.ops import IndexSet, gather

pyro.settings.set(module_local_params=True)

sns.set_style("white")

# pyro.set_rng_seed(321) # for reproducibility

In [2]:
gaussian_link = lambda mu: dist.Normal(mu, 1.)
bernoulli_link = lambda mu: dist.Bernoulli(logits=mu)

class HighDimLinearModel(pyro.nn.PyroModule):
    def __init__(self, N: int, p: int, link_fn: Callable[..., dist.Distribution] = gaussian_link):
        super().__init__()
        self.N = N
        self.p = p
        self.link_fn = link_fn
    
    def sample_outcome_weights(self):
        return pyro.sample("outcome_weights", dist.Normal(0.,  1./math.sqrt(self.p)).expand((self.p, )).to_event(1))
    
    def sample_propensity_weights(self):
        return pyro.sample("propensity_weights", dist.Normal(0., 1./math.sqrt(self.p)).expand((self.p,)).to_event(1))

    def sample_treatment_weight(self):
        return pyro.sample("treatment_weight", dist.Normal(0., 1.))
    
    def sample_covariate_loc_scale(self):
        loc = pyro.sample("covariate_loc", dist.Normal(0., 1.).expand((self.p,)).to_event(1))
        scale = pyro.sample("covariate_scale", dist.LogNormal(0, 1).expand((self.p,)).to_event(1))
        return loc, scale

    def forward(self):
        outcome_weights = self.sample_outcome_weights()
        propensity_weights = self.sample_propensity_weights()
        tau = self.sample_treatment_weight()
        x_loc, x_scale = self.sample_covariate_loc_scale()
        with pyro.plate("obs", self.N, dim=-1):
            X = pyro.sample("X", dist.Normal(x_loc, x_scale).to_event(1))
            A = pyro.sample("A", dist.Bernoulli(logits=torch.einsum("...np,...p->...n", X, propensity_weights)))
            return pyro.sample("Y", self.link_fn(torch.einsum("...np,...p->...n", X, outcome_weights) + A * tau))
        

class BenchmarkLinearModel(HighDimLinearModel):
    def __init__(self, N: int, p: int, link_fn: Callable, alpha: int, beta: int):
        super().__init__(N, p, link_fn)
        self.alpha = alpha # sparsity of propensity weights
        self.beta = beta # sparisty of outcome weights
    
    def sample_outcome_weights(self):
        outcome_weights = 1 / math.sqrt(self.beta) * torch.ones(self.p)
        outcome_weights[self.beta:] = 0.
        return outcome_weights
    
    def sample_propensity_weights(self):
        propensity_weights = 1 / math.sqrt(4 * self.alpha) * torch.ones(self.p)
        propensity_weights[self.alpha:] = 0.
        return propensity_weights

    def sample_treatment_weight(self):
        return torch.tensor(0.)
    
    def sample_covariate_loc_scale(self):
        return torch.zeros(self.p), torch.ones(self.p)
    
# model here implictly already is conditioning on X_train and on particular theta
def ATE(model: Callable[[], torch.Tensor], num_samples: int = 100) -> torch.Tensor:
    """Compute the average treatment effect of a model."""
    @pyro.plate("num_samples", num_samples, dim=-2)
    def _ate_model():
        with MultiWorldCounterfactual():
            with do(actions=dict(A=(torch.tensor(0.), torch.tensor(1.)))):
                Ys = model()
            Y0 = gather(Ys, IndexSet(A={1}), event_dim=0)
            Y1 = gather(Ys, IndexSet(A={2}), event_dim=0)
            return pyro.deterministic("ATE", (Y1 - Y0).mean(dim=-1, keepdim=True))
    
    return _ate_model().mean(dim=-2, keepdim=True).squeeze()


def ATE_2(model: Callable[[], torch.Tensor], num_samples: int = 100) -> torch.Tensor:
    """Compute the average treatment effect of a model."""
    @pyro.plate("num_samples", num_samples, dim=-2)
    def _ate_model():
        with do(actions=dict(A=torch.tensor(0.))):
            Y0 = model()
        with do(actions=dict(A=torch.tensor(1.))):
            Y1 = model()
        return pyro.deterministic("ATE", (Y1 - Y0).mean(dim=-1, keepdim=True))
    
    return _ate_model().mean(dim=-2, keepdim=True).squeeze()


In [3]:
def flatten_dict(d: Dict[str, torch.tensor]) -> torch.tensor:
    """
    Flatten a dictionary of tensors into a single vector.
    """
    return torch.cat([v.flatten() for k, v in d.items()])


def unflatten_dict(x: torch.tensor, d: Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:
    """
    Unflatten a vector into a dictionary of tensors.
    """
    return collections.OrderedDict(zip(
        d.keys(), [v_flat.reshape(v.shape) for v, v_flat in zip(d.values(), torch.split(x, [v.numel() for k, v in d.items()]))]
    ))


def one_step_correction(
    target_functional: Callable[[Callable], torch.tensor],
    model: Callable[[], torch.tensor],
    theta_hat: Dict[str, torch.tensor],
    X_train: Dict[str, torch.tensor],
    X_test: Dict[str, torch.tensor],
    *,
    eps_fisher: float = 1e-5,
) -> torch.tensor:
    """
    One step correction for a given target functional.
    """
    theta_hat = collections.OrderedDict((k, theta_hat[k]) for k in sorted(theta_hat.keys()))
    model_theta_hat = condition(data=theta_hat)(model)

    plug_in = target_functional(model_theta_hat) + (0 * sum(theta_hat[k].sum() for k in {"propensity_weights"}))
    plug_in_grads = collections.OrderedDict(zip(theta_hat.keys(), torch.autograd.grad(plug_in, theta_hat.values())))
    
    # compute the score function for the new data
    model_theta_hat_test = condition(data=X_test)(model_theta_hat)
    log_likelihood_test = pyro.poutine.trace(model_theta_hat_test).get_trace().log_prob_sum() / X_test[next(iter(X_test))].shape[0]
    scores = collections.OrderedDict(zip(theta_hat.keys(), torch.autograd.grad(log_likelihood_test, theta_hat.values())))

    # compute the fisher information matrix for the model, not the target functional
    # we use the training data to estimate the fisher information matrix along with theta_hat itself
    def _f_hess(flat_theta: torch.tensor) -> torch.tensor:
        theta = unflatten_dict(flat_theta, theta_hat)
        model_theta_hat_fisher = condition(data=X_train)(condition(data=theta)(model))
        log_likelihood_fisher = pyro.poutine.trace(model_theta_hat_fisher).get_trace().log_prob_sum() / X_train[next(iter(X_train))].shape[0]
        return log_likelihood_fisher

    fisher_info_approx = torch.autograd.functional.hessian(_f_hess, flatten_dict(theta_hat))


    # compute the correction
    plug_in_grad = flatten_dict(plug_in_grads)
    print(plug_in_grads)
    score = flatten_dict(scores)
    inverse_fisher_info = torch.inverse(fisher_info_approx + eps_fisher * torch.eye(fisher_info_approx.shape[0]))
    return torch.einsum("i,ij,j->", plug_in_grad, inverse_fisher_info, score)

Below we generate synthetic data as in Figure 4b of Kennedy (2022).

In [4]:
p = 500
alpha = 50
beta = 50
N_train = 200
N_test = N_train  # 500  # TODO refactor model and ATE to not require N_test == N_train
benchmark_model_train = BenchmarkLinearModel(N_train, p, gaussian_link, alpha, beta)
benchmark_model_test = BenchmarkLinearModel(N_test, p, gaussian_link, alpha, beta)

with pyro.poutine.trace() as train_tr:
    benchmark_model_train()

with pyro.poutine.trace() as test_tr:
    benchmark_model_test()

D_train = {k: train_tr.trace.nodes[k]["value"] for k in ["X", "A", "Y"]}
D_test = {k: test_tr.trace.nodes[k]["value"] for k in ["X", "A", "Y"]}

In [5]:
# Fit model to training data (uncorrected)
class ConditionModelTrain(HighDimLinearModel):
    def forward(self):
        with condition(data=D_train):
            return super().forward()
    
model_train = ConditionModelTrain(N_train, p, gaussian_link)
guide_train = pyro.infer.autoguide.AutoDelta(model_train)
elbo = pyro.infer.Trace_ELBO()(model_train, guide_train)

# initialize parameters
elbo()

adam = torch.optim.Adam(elbo.parameters(), lr=0.03)

# Do gradient steps
for step in range(2000):
    adam.zero_grad()
    loss = elbo()
    loss.backward()
    adam.step()
    if step % 250 == 0:
        print("[iteration %04d] loss: %.4f" % (step, loss))

[iteration 0000] loss: 159499.7969
[iteration 0250] loss: 140658.6406
[iteration 0500] loss: 140658.6250
[iteration 0750] loss: 140658.8125
[iteration 1000] loss: 140658.9688
[iteration 1250] loss: 140658.6875
[iteration 1500] loss: 140658.8594
[iteration 1750] loss: 140659.0469


In [6]:
theta_hat = {k: v.clone().detach().requires_grad_(True) for k, v in guide_train().items()}
print(theta_hat.keys(), theta_hat["treatment_weight"])

dict_keys(['outcome_weights', 'propensity_weights', 'treatment_weight', 'covariate_loc', 'covariate_scale']) tensor(0.0322, requires_grad=True)


In [7]:
model_test = condition(data=theta_hat)(HighDimLinearModel(N_test, p, gaussian_link))

ATE_plugin = ATE_2(model_test, num_samples=1000)
print("ATE plugin", ATE_plugin)

ATE plugin tensor(0.0330, grad_fn=<SqueezeBackward0>)


In [8]:
ATE_correction = one_step_correction(
    lambda m: ATE_2(m, num_samples=1000),
    HighDimLinearModel(N_test, p, gaussian_link),
    theta_hat,
    D_train,
    D_test,
    eps_fisher=1e-10,
)
ATE_onestep = ATE_plugin + ATE_correction
print(ATE_plugin, ATE_correction, ATE_onestep)

OrderedDict([('covariate_loc', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0