In [134]:
import torch
from chirho.interventional.handlers import Interventions
from chirho.counterfactual.handlers import MultiWorldCounterfactual
import functools
import chirho

import chirho.interventional.ops
from chirho.robust.handlers.predictive import PredictiveFunctional, PredictiveModel
from chirho.robust.handlers.estimators import MonteCarloInfluenceEstimator, one_step_corrected_estimator
from chirho.robust.ops import influence_fn
import pyro
from typing import Mapping
from tests.robust.robust_fixtures import SimpleModel, SimpleGuide, HighDimLinearModel
import pyro.distributions as dist
from chirho.observational.handlers import condition

from typing import Callable, Optional, Tuple
from fractions import Fraction

import functools
import torch
import math
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import time

import pyro
import pyro.distributions as dist
from pyro.infer import Predictive
import pyro.contrib.gp as gp

from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.indexed.ops import IndexSet, gather
from chirho.interventional.handlers import do
from chirho.robust.internals.utils import ParamDict
from chirho.robust.handlers.estimators import one_step_corrected_estimator 
from chirho.robust.ops import influence_fn
from chirho.robust.handlers.predictive import PredictiveModel, PredictiveFunctional
from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood


pyro.settings.set(module_local_params=True)

sns.set_style("white")

pyro.set_rng_seed(32891) # for reproducibility

In [4]:
def test_estimator_smoke(
    model,
    guide,
    obs_names,
    max_plate_nesting,
    num_samples_outer,
    num_samples_inner,
    cg_iters,
    num_predictive_samples,
    estimation_method,
):
    model = model()
    guide = guide(model)
    model(), guide()  # initialize

    with torch.no_grad():
        test_datum = {
            k: v[0]
            for k, v in pyro.infer.Predictive(
                model, num_samples=2, return_sites=obs_names, parallel=True
            )().items()
        }

    estimator = estimation_method(
        functools.partial(PredictiveFunctional, num_samples=num_predictive_samples),
        test_datum,
    )(PredictiveModel(model, guide))

    with MonteCarloInfluenceEstimator(
        max_plate_nesting=max_plate_nesting,
        num_samples_outer=num_samples_outer,
        num_samples_inner=num_samples_inner,
        cg_iters=cg_iters,
    ):
        estimate_on_test: Mapping[str, torch.Tensor] = estimator()
    assert len(estimate_on_test) > 0
    for k, v in estimate_on_test.items():
        assert not torch.isnan(v).any(), f"{estimation_method} for {k} had nans"
        assert not torch.isinf(v).any(), f"{estimation_method} for {k} had infs"
        assert not torch.isclose(
            v, torch.zeros_like(v)
        ).all(), f"{estimation_method} estimator for {k} was zero"
    
    print(estimate_on_test)


In [5]:
# with torch.no_grad():
test_estimator_smoke(
    model=SimpleModel,
    guide=lambda _: SimpleGuide(),
    obs_names={"y"},
    max_plate_nesting=1,
    num_samples_outer=1000,
    num_samples_inner=10000,
    cg_iters=2000,
    num_predictive_samples=2000,
    estimation_method=one_step_corrected_estimator,
)



{'y': tensor([[[[[[[-0.0875,  3.0717,  3.0262]]]]],




         [[[[[-0.3362, -1.3289, -0.5850]]]]],




         [[[[[ 1.1805,  1.0019, -1.0683]]]]],




         ...,




         [[[[[ 0.8283,  1.1114,  0.3105]]]]],




         [[[[[-1.2064, -1.1788, -1.5024]]]]],




         [[[[[-0.2853, -0.4361,  0.8307]]]]]]], grad_fn=<AddBackward0>)}


In [59]:
hdl_model = HighDimLinearModel(p=10)
hdl_model()

tensor([-0.5445])

In [89]:
def mc_ate_functional(model, n, binary_treatment_name="A"):
    with pyro.plate("n", n, dim=-2):
        with Interventions({binary_treatment_name: 0}):
            y0 = model()
        with Interventions({binary_treatment_name: 1}):
            y1 = model()
        
    return y1.mean() - y0.mean()

In [88]:
mc_ate_functional(hdl_model, 1000)

tensor(0.0173)

In [125]:
with pyro.plate('data', 30, dim=-2):
    full_sample = PredictiveFunctional(hdl_model)()
    observed_sample = dict(
        Y=full_sample['Y'],
        A=full_sample['A'],
        X=full_sample['X']
    )

In [127]:
for k, v in observed_sample.items():
    print(k, v.shape)

Y torch.Size([30, 30, 1])
A torch.Size([30, 30, 1])
X torch.Size([30, 1, 10])


In [128]:
conditioned_hdl_model = chirho.observational.handlers.condition(hdl_model, observed_sample)

In [133]:
svi = pyro.infer.SVI(
    hdl_model,
    pyro.infer.autoguide.AutoDiagonalNormal(hdl_model),
    pyro.optim.Adam({'lr': 0.01}),
    pyro.infer.TraceEnum_ELBO()
)
for _ in range(1000):
    svi.step()

ValueError: Continuous inference cannot handle discrete sample site 'A'. Consider enumerating that variable as documented in https://pyro.ai/examples/enumeration.html . If you are already enumerating, take care to hide this site when constructing an autoguide, e.g. guide = AutoNormal(poutine.block(model, hide=['A'])).
Trace Shapes:
 Param Sites:
Sample Sites: