In [24]:
%env CUDA_VISIBLE_DEVICES=-1
from typing import Callable, Dict, List, Optional

import math
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch
from chirho.counterfactual.handlers.counterfactual import \
    MultiWorldCounterfactual
from chirho.explainable.handlers import ExtractSupports, SearchForExplanation
from chirho.indexed.ops import IndexSet, gather
from chirho.observational.handlers import condition
from chirho.observational.handlers.soft_conditioning import soft_eq, KernelSoftConditionReparam

# Notebook-wide Pyro configuration. Per the Pyro settings docs, `module_local_params`
# controls whether PyroModule parameters live on the module or in the global param store.
# NOTE(review): no PyroModule is defined in the visible cells - confirm this is still needed.
pyro.settings.set(module_local_params=True)

env: CUDA_VISIBLE_DEVICES=-1


In [25]:
def importance_infer(
    model: Optional[Callable] = None, *, num_samples: int, max_plate_nesting: int = 9
):
    """Wrap ``model`` so that calling it runs vectorized importance sampling.

    Can be used directly (``importance_infer(model, num_samples=n)``) or as a
    decorator factory (``importance_infer(num_samples=n)(model)``).

    Args:
        model: The Pyro model to wrap. When omitted, a decorator that expects
            the model is returned instead.
        num_samples: Number of importance samples drawn per call; must be >= 1.
        max_plate_nesting: Forwarded to
            ``pyro.infer.importance.vectorized_importance_weights``. Defaults to
            9, the value that used to be hard-coded here as a guess; override it
            when the model's plate structure is known.

    Returns:
        A callable accepting the model's own ``*args``/``**kwargs`` and returning
        the 4-tuple ``(log_mean_weight, importance_tr, mwc, log_weights)`` where
        ``log_mean_weight`` is ``logsumexp(log_weights, dim=0) - log(num_samples)``,
        ``importance_tr`` is the second value returned by
        ``vectorized_importance_weights``, ``mwc`` is the
        ``MultiWorldCounterfactual`` handler that was active during sampling, and
        ``log_weights`` are the unnormalized log importance weights.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    # Fail fast: a non-positive sample count would otherwise only surface later,
    # as a math domain error from ``math.log`` (or inside Pyro).
    if num_samples < 1:
        raise ValueError(f"num_samples must be a positive integer, got {num_samples}")

    if model is None:
        # Decorator-factory form: defer until the model is supplied.
        return lambda m: importance_infer(
            m, num_samples=num_samples, max_plate_nesting=max_plate_nesting
        )

    def _wrapped_model(
        *args,
        **kwargs
    ):
        # The guide is the model itself with every observed site hidden, so
        # proposals come from the model's own (unobserved) sample statements.
        guide = pyro.poutine.block(hide_fn=lambda msg: msg["is_observed"])(model)

        # The outer block keeps the sites sampled here from leaking to any
        # handlers that enclose this call.
        with pyro.poutine.block(), MultiWorldCounterfactual() as mwc:
            log_weights, importance_tr, _ = pyro.infer.importance.vectorized_importance_weights(
                model,
                guide,
                *args,
                num_samples=num_samples,
                max_plate_nesting=max_plate_nesting,
                normalized=False,
                **kwargs
            )

        # log of the average unnormalized importance weight, reduced over dim 0.
        log_mean_weight = torch.logsumexp(log_weights, dim=0) - math.log(num_samples)
        return log_mean_weight, importance_tr, mwc, log_weights

    return _wrapped_model

In [26]:
def example():
    """Toy model with three Bernoulli sample sites.

    ``A`` and ``B`` are each drawn from ``Bernoulli(0.5)``; ``C`` is drawn from a
    Bernoulli whose parameter is the sampled value of ``A``. ``B`` is not used by
    any other site.

    Returns:
        A dict mapping the site names ``"A"``, ``"B"``, ``"C"`` to their values.
    """
    sampled = {}
    sampled["A"] = pyro.sample("A", dist.Bernoulli(0.5))
    sampled["B"] = pyro.sample("B", dist.Bernoulli(0.5))
    # C depends only on A.
    sampled["C"] = pyro.sample("C", dist.Bernoulli(sampled["A"]))
    return sampled

# Run the model once under ExtractSupports; the handler's `.supports` attribute is
# what the SearchForExplanation query consumes afterwards.
# (Presumably it records the support of each sample site - confirm against the ChiRho docs.)
with ExtractSupports() as extract_supports:
    example()

In [30]:
# Build the explanation query around `example`: as named by the keyword arguments,
# the candidate causes are A=1 and B=1 (with alternative values 0), the outcome is
# C=1, and no witness sites are used.
query = SearchForExplanation(
    supports=extract_supports.supports,
    antecedents={"A": 1.0, "B": 1.0},
    alternatives={"A": 0.0, "B": 0.0},
    witnesses={},
    consequents={"C": torch.tensor(1.0)},
    antecedent_bias=0.4,
    consequent_scale=1e-5,
)(example)

# Draw 10000 importance samples from the query; keep the trace and the raw
# log-weights so the cells below can slice them by intervention pattern.
logp, trace, mwc, log_weights = importance_infer(query, num_samples=10000)()
# exp(logp) is the average unnormalized importance weight.
print(logp.exp())

tensor(0.1057)


In [31]:
# Average the importance weights over the samples in which B was intervened on
# (antecedent indicator equal to 0), regardless of what happened to A.
# Restricting to "B was intervened on" gives an answer that accounts for the
# causal role of the set {A = 1, B = 1}.
mask_intervened = trace.nodes["__cause____antecedent_B"]["value"].eq(0)
print((log_weights.exp() * mask_intervened.float().squeeze()).sum() / mask_intervened.float().sum())

tensor(0.1077)


In [29]:
# Average the importance weights over the samples in which B was intervened on
# (indicator equal to 0) while A was not (indicator equal to 1).
# Restricting to "B intervened on, A left alone" gives an answer that agrees with
# the fact that B has no causal role.
mask_intervened = trace.nodes["__cause____antecedent_B"]["value"].eq(0) & trace.nodes["__cause____antecedent_A"]["value"].eq(1)
print((log_weights.exp() * mask_intervened.float().squeeze()).sum() / mask_intervened.float().sum())


tensor(5.1220e-06)
