# Doubly robust estimation with ChiRho

## Outline

- [Setup](#setup)
- [Overview: Robust Causal Inference with Cut Modules](#overview:-robust-causal-inference-with-cut-modules)
- [Example: Synthetic data generation from a high-dimensional generalized linear model](#example:-synthetic-data-generation-from-a-high-dimensional-generalized-linear-model)
- [Effect estimation using cut modules](#effect-estimation-using-cut-modules)
- [References](#references)

## Setup

In [4]:
import math
import seaborn as sns
import matplotlib.pyplot as plt
from typing import Callable, Dict

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal
from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.observational.handlers.cut import SingleStageCut
from pyro.infer import Predictive
from typing import Callable, Dict, List, Optional, Tuple, Union

# Turn on Pyro's `module_local_params` setting (see the pyro.settings docs for its exact semantics).
pyro.settings.set(module_local_params=True)

# Use seaborn's "white" style for all plots.
sns.set_style("white")

# Fix the random seed so the sampled data below are repeatable across runs.
pyro.set_rng_seed(321) # for reproducibility

In [6]:
def gaussian_link(mu):
    """Return a unit-scale Normal distribution centred at the linear predictor ``mu``."""
    return dist.Normal(mu, 1.)


def bernoulli_link(mu):
    """Return a Bernoulli distribution whose logits are the linear predictor ``mu``."""
    return dist.Bernoulli(logits=mu)

class HighDimLinearModel(pyro.nn.PyroModule):
    """Generalized linear outcome model over ``p`` covariates plus a treatment.

    The linear predictor is ``X @ outcome_weights + A * treatment_weight``;
    ``link_fn`` turns it into the distribution of the outcome site ``"Y"``.
    """

    def __init__(self, p: int, link_fn: Callable = gaussian_link):
        """
        :param p: number of covariates, i.e. the length of the outcome weight vector.
        :param link_fn: callable mapping the linear predictor to a distribution
            for ``"Y"``; defaults to ``gaussian_link``.
        """
        super().__init__()
        self.link_fn = link_fn
        self.p = p

    def sample_outcome_weights(self):
        """Sample the ``"outcome_weights"`` site: ``p`` Normal(0, 1/sqrt(p)) entries as one event."""
        prior_scale = 1. / math.sqrt(self.p)
        weight_prior = dist.Normal(0., prior_scale).expand((self.p, )).to_event(1)
        return pyro.sample("outcome_weights", weight_prior)

    def sample_treatment_weight(self):
        """Sample the scalar ``"treatment_weight"`` site from a standard Normal."""
        treatment_prior = dist.Normal(0., 1.)
        return pyro.sample("treatment_weight", treatment_prior)

    def forward(self, X: torch.Tensor, A: torch.Tensor):
        """Sample the outcomes ``"Y"`` for every row of ``X``.

        :param X: covariates; the leading dimension indexes observations and
            the trailing dimension is multiplied against the outcome weights.
        :param A: treatments, combined elementwise with the treatment weight.
        """
        num_obs = X.shape[0]
        # The outcome weights are drawn before the treatment weight; keep this
        # order so the random stream is consumed the same way on every run.
        weights = self.sample_outcome_weights()
        treatment_effect = self.sample_treatment_weight()
        linear_predictor = X @ weights + A * treatment_effect
        with pyro.plate("obs", num_obs):
            pyro.sample("Y", self.link_fn(linear_predictor))
        

class BenchmarkLinearModel(HighDimLinearModel):
    """``HighDimLinearModel`` whose weights are fixed constants instead of random draws.

    The first ``beta`` outcome weights equal ``1 / sqrt(beta)``, the remaining
    ones are zero, and the treatment weight is exactly zero.
    """

    def __init__(self, p: int, link_fn: Callable, alpha: int, beta: int):
        """
        :param p: number of covariates.
        :param link_fn: callable mapping the linear predictor to the outcome distribution.
        :param alpha: sparsity of the propensity weights (stored only; not read by this class).
        :param beta: sparsity of the outcome weights (number of non-zero entries).
        """
        super().__init__(p, link_fn)
        self.alpha = alpha  # sparsity of propensity weights
        self.beta = beta  # sparsity of outcome weights

    def sample_outcome_weights(self):
        """Return the deterministic outcome weight vector of length ``p``."""
        weights = torch.zeros(self.p)
        weights[:self.beta] = 1 / math.sqrt(self.beta)
        return weights

    def sample_treatment_weight(self):
        """Return a scalar zero tensor: the benchmark has no treatment effect."""
        return torch.zeros(())


In [3]:
# NOTE: X_fisher should probably be the training data.

def one_step_correction(
    model: Callable[[torch.Tensor], None],
    X_fisher: Dict[str, torch.Tensor],
    X_new: Dict[str, torch.Tensor],
    theta_hat: torch.Tensor,
    target_functional: Callable[[Callable], torch.Tensor],
) -> torch.Tensor:
    """
    One step correction for a given target functional.

    Returns ``grad(psi)^T  I^{-1}  score`` where ``psi`` is the plug-in value
    of ``target_functional`` at ``theta_hat``, ``I`` is the Fisher information
    estimated on ``X_fisher`` and ``score`` is the gradient of the mean
    log-likelihood of ``X_new``, all evaluated at ``theta_hat``.

    :param model: callable that takes the parameter tensor and runs a Pyro model.
    :param X_fisher: site name -> values used to estimate the Fisher information.
        The number of observations is read from the leading dimension of the
        first entry.
    :param X_new: site name -> values on which the score is evaluated; same
        convention as ``X_fisher``.
    :param theta_hat: parameter estimate at which everything is evaluated.
    :param target_functional: maps a zero-argument model callable to a scalar
        tensor that is differentiable with respect to ``theta_hat``.
    :return: scalar tensor holding the correction term.
    """

    def _mean_log_likelihood(theta: torch.Tensor, data: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Per-observation average log-probability of the model conditioned on `data`.
        conditioned_model = pyro.poutine.condition(data=data)(lambda: model(theta))
        num_obs = data[next(iter(data))].shape[0]
        return pyro.poutine.trace(conditioned_model).get_trace().log_prob_sum() / num_obs

    # torch.autograd.grad needs a tensor that tracks gradients.
    if not theta_hat.requires_grad:
        theta_hat = theta_hat.detach().clone().requires_grad_(True)

    plug_in = target_functional(lambda: model(theta_hat))
    # torch.autograd.grad returns a tuple with one entry per input; unpack it.
    (plug_in_grad,) = torch.autograd.grad(plug_in, theta_hat, create_graph=True)

    # compute the fisher information matrix for the model, not the target functional
    # we use the training data to estimate the fisher information matrix along with theta_hat itself
    hessian = torch.autograd.functional.hessian(
        lambda theta: _mean_log_likelihood(theta, X_fisher), theta_hat
    )
    num_params = theta_hat.numel()
    # The Fisher information is the *negative* Hessian of the log-likelihood.
    fisher_info_approx = -hessian.reshape(num_params, num_params)

    # compute the score function for the new data
    log_likelihood_new = _mean_log_likelihood(theta_hat, X_new)
    (score,) = torch.autograd.grad(log_likelihood_new, theta_hat, create_graph=True)

    # compute the correction; solving the linear system avoids forming the
    # explicit inverse of the Fisher information matrix.
    return plug_in_grad.reshape(-1) @ torch.linalg.solve(fisher_info_approx, score.reshape(-1))

Below we generate synthetic data as in Figure 4b of Kennedy (2022).

In [12]:
# Benchmark configuration.
p = 500  # number of covariates
alpha = 50  # number of non-zero propensity weights
beta = 50  # number of non-zero outcome weights
N_train = 200  # training sample size
N_test = 500  # test sample size
benchmark_model = BenchmarkLinearModel(p, gaussian_link, alpha, beta)
# Propensity weights: the first `alpha` entries equal 1 / sqrt(4 * alpha), the rest are zero.
true_propensity_weights = 1 / math.sqrt(4 * alpha) * torch.ones(p)
true_propensity_weights[alpha:] = 0.
# Covariates are i.i.d. standard normal; treatments are Bernoulli draws with
# logits X @ true_propensity_weights. The order of these draws determines how
# the seeded random stream is consumed.
X_train = dist.Normal(0., 1.).expand((N_train, p)).to_event(1).sample()
A_train = dist.Bernoulli(logits=X_train @ true_propensity_weights).sample()
X_test = dist.Normal(0., 1.).expand((N_test, p)).to_event(1).sample()
A_test = dist.Bernoulli(logits=X_test @ true_propensity_weights).sample()

# Run the benchmark model under a trace so the sampled outcomes "Y" can be read back.
with pyro.poutine.trace() as training_data:
    benchmark_model(X_train, A_train)

with pyro.poutine.trace() as testing_data:
    benchmark_model(X_test, A_test)

# Bundle covariates, treatments and traced outcomes for each split.
Y_train = training_data.trace.nodes["Y"]["value"]
D_train = {"X": X_train, "A": A_train, "Y": Y_train}

Y_test = testing_data.trace.nodes["Y"]["value"]
D_test = {"X": X_test, "A": A_test, "Y": Y_test}

In [43]:
# Do ridge regression on the training data
ridge_penalty = .001
# Design matrix with the treatment in column 0 followed by the p covariates.
X_tilde_train = torch.cat([A_train.reshape(-1, 1), X_train], dim=1)
# Solve the regularized normal equations (X'X + penalty * I) theta = X'Y.
# A linear solve is used instead of an explicit matrix inverse because the
# Gram matrix is badly conditioned when there are fewer samples than features.
ridge_gram = X_tilde_train.T @ X_tilde_train + ridge_penalty * torch.eye(p + 1)
theta_ridge = torch.linalg.solve(ridge_gram, X_tilde_train.T @ Y_train)
# True coefficients in the same layout: index 0 is the treatment, indices
# 1..beta carry the non-zero outcome weights, and the remainder is zero.
theta_true = 1 / math.sqrt(beta) * torch.ones(p + 1)
theta_true[(beta + 1):] = 0.
theta_true[0] = 0. # true treatment effect is zero

In [44]:
# Squared L2 distance between the ridge estimate and the true coefficient vector.
((theta_ridge - theta_true) ** 2).sum()

tensor(2.1544)

In [45]:
# Squared L2 norm of the true coefficient vector.
(theta_true ** 2).sum()

tensor(1.0000)

In [46]:
# Square of the ridge coefficient at index 0, the column that holds the treatment in X_tilde_train.
theta_ridge[0] ** 2

tensor(0.0155)

In [None]:
class LinearGaussianModel(pyro.nn.PyroModule):
    """Linear-Gaussian model of covariates ``X``, treatment ``A`` and outcome ``Y``.

    ``X`` is standard normal with ``p`` entries, ``A`` is Bernoulli with logits
    ``X @ theta_propensity`` and ``Y`` is Normal with a mean that is linear in
    ``A`` and ``X``. Any of the three sites can be observed by passing its value.
    """

    def __init__(self, p: int):
        """
        :param p: number of covariates.
        """
        super().__init__()
        self.p = p

    def forward(
        self,
        theta: Dict[str, torch.Tensor],
        X: Optional[torch.Tensor] = None,
        A: Optional[torch.Tensor] = None,
        Y: Optional[torch.Tensor] = None,
    ):
        """Run the model, observing whichever of ``X``, ``A``, ``Y`` are given.

        :param theta: parameters with keys ``'noise_scale'``,
            ``'theta_propensity'`` and ``'theta_outcome'``.
        :param X: optional observed covariate vector.
        :param A: optional observed treatment.
        :param Y: optional observed scalar outcome.
        """
        # Check each observation on its own so that passing only some of them
        # does not fail on a missing (None) argument.
        if X is not None:
            assert len(X.shape) == 1, "X should be a vector"
        if A is not None:
            # NOTE(review): the original check requires a 1-D treatment even though
            # X is a single covariate vector -- confirm the intended shape of A.
            assert len(A.shape) == 1, "A should be a vector"
        if Y is not None:
            assert len(Y.shape) == 0, "Y should be a scalar"

        noise_scale = theta['noise_scale']
        theta_propensity = theta['theta_propensity']
        theta_outcome = theta['theta_outcome']
        X = pyro.sample("X", dist.Normal(0., 1.).expand((self.p, )).to_event(1), obs=X)
        # A and Y are passed as observations; they were previously accepted but ignored.
        A = pyro.sample("A", dist.Bernoulli(logits=X @ theta_propensity), obs=A)
        # Assumes theta_outcome has p + 1 entries: the treatment coefficient first,
        # then one coefficient per covariate -- TODO confirm against the caller.
        mean = A * theta_outcome[0] + X @ theta_outcome[1:]
        pyro.sample("Y", dist.Normal(mean, noise_scale), obs=Y)
    
    




def linear_gaussian_model_with_cov(theta, X, A):
    """Sample the outcome site ``"Y"`` of a unit-variance linear-Gaussian model.

    :param theta: coefficient vector; ``theta[0]`` multiplies the treatment and
        ``theta[1:]`` multiplies the covariates. This assumes the same layout as
        the ridge design matrix (treatment column first) -- TODO confirm.
    :param X: covariate matrix with one row per observation.
    :param A: treatments, one per row of ``X``.
    """
    # Use only the covariate coefficients for X so that theta[0] is not counted
    # twice and the matrix-vector product has matching sizes.
    mean = A * theta[0] + X.mv(theta[1:])
    pyro.sample("Y", dist.Normal(mean, 1.))

# Training covariates and treatments, presumably the data for the Fisher-information estimate.
# NOTE(review): there is no "Y" entry here -- confirm whether the outcome should be included.
X_fisher = {"X": X_train, "A": A_train}

In [None]:

# NOTE(review): this line is the signature of one_step_correction pasted as a call.
# Annotated arguments are not valid in a call expression, so this cell cannot run
# until the annotations are replaced with concrete argument values.
one_step_correction(model: Callable[[torch.tensor], None], X_fisher: Dict[str, torch.tensor], X_new: Dict[str, torch.tensor], theta_hat: torch.tensor, target_functional: Callable[[Callable], torch.tensor])