In [31]:
from typing import Optional, Callable
import math

import pyro.distributions as dist
import torch

import pyro
from chirho.counterfactual.handlers.counterfactual import \
    MultiWorldCounterfactual
from chirho.explainable.handlers import SearchForExplanation
from chirho.explainable.handlers.components import ExtractSupports


In [32]:
def model():
    """Toy model with two independent standard-normal sample sites, ``a`` and ``b``.

    ``b`` is drawn without reference to ``a``, so intervening on ``a`` cannot
    change ``b``. The function returns ``None``; it is used only for the
    sample sites it records.
    """
    for site_name in ("a", "b"):
        pyro.sample(
            site_name,
            dist.Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)),
        )

In [33]:
# Run the model once under ExtractSupports; presumably this collects each
# sample site's support, exposed as `s.supports` for SearchForExplanation.
with ExtractSupports() as s:
    model()

# Explanation query for "a = -0.5 causes b = 0".
# NOTE(review): the per-argument notes below are inferred from the
# SearchForExplanation parameter names — confirm against the ChiRho docs.
query = SearchForExplanation(
    supports=s.supports,
    antecedents={"a": torch.tensor(-0.5)},  # candidate cause value of `a`
    alternatives={"a": torch.tensor(0.5)},  # contrasting value of `a` (presumably for necessity)
    witnesses={},  # no witness sites
    consequents={"b": torch.tensor(0.0)},  # effect whose explanation is sought
    antecedent_bias=0.0,
    consequent_scale=1e-8,  # very small scale; presumably a near-exact match on continuous `b`
)(model)

How can I compute the probability that `a=-0.5` is a necessary and sufficient cause of `b=0` using `SearchForExplanation`?

In [34]:
def importance_infer(
    model: Optional[Callable] = None,
    *,
    num_samples: int,
    max_plate_nesting: int = 9,
) -> Callable:
    """Turn ``model`` into a callable that runs vectorized importance sampling.

    Usable directly, ``importance_infer(model, num_samples=N)``, or as a
    decorator factory, ``importance_infer(num_samples=N)(model)``.

    The proposal ("guide") is ``model`` itself with its observed sites hidden.
    Latent sites are therefore drawn from the model, and the observed sites
    (which in Pyro include ``pyro.factor`` terms) contribute the importance
    weights. Sampling runs inside a fresh ``MultiWorldCounterfactual`` and
    under ``pyro.poutine.block()``, so enclosing handlers do not see the
    individual sample sites.

    Args:
        model: The model to wrap. If ``None``, a decorator is returned instead.
        num_samples: Number of importance samples (particles); must be >= 1.
        max_plate_nesting: Forwarded to
            ``pyro.infer.importance.vectorized_importance_weights``. It is an
            upper bound on plate nesting, so the particle dimension can be
            placed to the left of the model's plates. The default of 9 is a
            deliberately generous guess. Presumably it must also cover the
            plates allocated by ``MultiWorldCounterfactual``; TODO confirm.

    Returns:
        If ``model`` is ``None``, a decorator taking the model. Otherwise, a
        function with ``model``'s call signature that returns the 4-tuple
        ``(log_evidence, importance_tr, mwc_imp, log_weights)``:

        - ``log_evidence``: ``logsumexp(log_weights) - log(num_samples)``,
          i.e. the log of the mean unnormalized importance weight (an
          estimate of the log marginal likelihood).
        - ``importance_tr``: the vectorized model trace.
        - ``mwc_imp``: the ``MultiWorldCounterfactual`` handler that was
          active during sampling.
        - ``log_weights``: the unnormalized per-sample log weights (per the
          Pyro docs, a ``(num_samples,)`` tensor).

    Raises:
        ValueError: If ``num_samples`` is less than 1. Without this check,
            the failure would only surface later, e.g. as a math domain error
            from ``math.log(0)``.
    """
    # Validate eagerly so the decorator form also fails at definition time.
    if num_samples < 1:
        raise ValueError(
            f"num_samples must be a positive integer, got {num_samples!r}"
        )

    if model is None:
        return lambda m: importance_infer(
            m, num_samples=num_samples, max_plate_nesting=max_plate_nesting
        )

    def _wrapped_model(*args, **kwargs):
        # Hide observed sites so the guide only proposes latent values.
        guide = pyro.poutine.block(hide_fn=lambda msg: msg["is_observed"])(model)

        # block(): keep the particle-level sites invisible to outer handlers.
        with pyro.poutine.block(), MultiWorldCounterfactual() as mwc_imp:
            log_weights, importance_tr, _ = (
                pyro.infer.importance.vectorized_importance_weights(
                    model,
                    guide,
                    *args,
                    num_samples=num_samples,
                    max_plate_nesting=max_plate_nesting,
                    normalized=False,
                    **kwargs,
                )
            )

        # log-mean-exp of the raw weights: log((1 / N) * sum_i w_i).
        log_evidence = torch.logsumexp(log_weights, dim=0) - math.log(num_samples)
        return log_evidence, importance_tr, mwc_imp, log_weights

    return _wrapped_model