In [14]:
import functools
import itertools
import contextlib
from typing import Callable, Iterable, TypeVar, Mapping

import pyro
import torch  # noqa: F401

from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.indexed.ops import IndexSet, cond, gather, indices_of, scatter

S = TypeVar("S")
T = TypeVar("T")

import pyro
import pyro.distributions as dist
import pyro.infer
import pytest
import torch
from scipy.stats import spearmanr

from chirho.counterfactual.handlers import MultiWorldCounterfactual
from chirho.counterfactual.handlers.counterfactual import Preemptions
from chirho.counterfactual.handlers.explanation import (
    consequent_differs,
    random_intervention,
    undo_split,
    uniform_proposal,
)
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.observational.handlers.condition import Factors
from chirho.interventional.ops import Intervention
from chirho.interventional.handlers import do




In [15]:


@functools.singledispatch
def uniform_proposal(
    support: pyro.distributions.constraints.Constraint,
    **kwargs,
) -> pyro.distributions.Distribution:
    """
    Heuristically build a broad proposal distribution over ``support``.

    This is the fallback implementation of a ``functools.singledispatch`` function;
    more specific constraint types can be registered on it separately.

    - ``constraints.real``: ``Normal(0, 100)`` wrapped with ``.mask(False)``.
    - ``constraints.boolean``: ``Bernoulli`` with logit ``0`` (success probability 0.5).
    - any other constraint: the ``real`` proposal pushed forward through
      ``biject_to(support)``.

    :param support: Constraint describing where the distribution should put its mass.
    :param kwargs: Extra keyword arguments; only forwarded to the recursive call that
        builds the real-valued base distribution.
    :return: A distribution over ``support``.
    """
    constraints = pyro.distributions.constraints

    if support is constraints.real:
        return pyro.distributions.Normal(0, 100).mask(False)

    if support is constraints.boolean:
        return pyro.distributions.Bernoulli(logits=torch.zeros(()))

    # Generic case: map an unconstrained real proposal onto the support.
    to_support = pyro.distributions.transforms.biject_to(support)
    real_proposal = uniform_proposal(constraints.real, **kwargs)
    return pyro.distributions.TransformedDistribution(real_proposal, to_support)


@uniform_proposal.register
def _uniform_proposal_indep(
    support: pyro.distributions.constraints.independent,
    *,
    event_shape: torch.Size = torch.Size([]),
    **kwargs,
) -> pyro.distributions.Distribution:
    """
    ``uniform_proposal`` implementation for ``constraints.independent`` supports.

    Builds the proposal for ``support.base_constraint`` (forwarding ``event_shape`` and
    any other keyword arguments), expands it to ``event_shape``, and then reinterprets
    ``support.reinterpreted_batch_ndims`` batch dimensions as event dimensions.

    :param support: An independent constraint wrapping a base constraint.
    :param event_shape: Shape the base proposal is expanded to before ``to_event``.
    :param kwargs: Extra keyword arguments forwarded to the base-constraint proposal.
    :return: A distribution with independent dimensions over ``support``.

    Example:
    ```
    indep = pyro.distributions.constraints.independent(
        pyro.distributions.constraints.real, reinterpreted_batch_ndims=2
    )
    proposal = uniform_proposal(indep, event_shape=torch.Size([2, 3]))
    ```
    """
    inner = uniform_proposal(
        support.base_constraint, event_shape=event_shape, **kwargs
    )
    expanded = inner.expand(event_shape)
    return expanded.to_event(support.reinterpreted_batch_ndims)


@uniform_proposal.register
def _uniform_proposal_integer(
    support: pyro.distributions.constraints.integer_interval,
    **kwargs,
) -> pyro.distributions.Distribution:
    """
    ``uniform_proposal`` implementation for ``constraints.integer_interval`` supports.

    Returns a ``Categorical`` with equal (unnormalized) weight on each of the
    ``upper_bound - lower_bound + 1`` integers in the interval. Only intervals whose
    lower bound is exactly 0 are handled, because ``Categorical`` is supported on
    ``{0, ..., n - 1}``.

    :param support: An integer_interval constraint with ``lower_bound == 0``.
    :param kwargs: Additional keyword arguments (unused).
    :return: A uniform categorical distribution over the interval.
    :raises NotImplementedError: If ``support.lower_bound`` is not 0 (positive or negative).

    Example:
    ```
    constraint = pyro.distributions.constraints.integer_interval(0, 2)
    proposal = _uniform_proposal_integer(constraint)
    samples = proposal.sample(torch.Size([100]))
    ```
    """
    if support.lower_bound != 0:
        # The check rejects any non-zero lower bound (including negative ones), so the
        # message says "!= 0" rather than the previous, misleading "> 0".
        raise NotImplementedError(
            "integer_interval with lower_bound != 0 not yet supported"
        )
    num_outcomes = support.upper_bound - support.lower_bound + 1
    return pyro.distributions.Categorical(probs=torch.ones((num_outcomes,)))


def random_intervention(
    support: pyro.distributions.constraints.Constraint,
    name: str,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Build an intervention that replaces a value with a fresh random draw.

    The returned callable uses only the *shape* of its argument. It takes the trailing
    ``support.event_dim`` dimensions of ``value`` as the event shape and builds a
    proposal over ``support`` with ``uniform_proposal``. It then draws from that
    proposal at the ``pyro.sample`` site ``name``.

    :param support: Constraint describing the values the sample site can take.
    :param name: Name of the ``pyro.sample`` site used for the random draw.
    :return: A function mapping a tensor to a random sample over ``support`` whose
        event shape matches the trailing event dimensions of that tensor.

    Example:
    ```
    support = pyro.distributions.constraints.real
    intervention_fn = random_intervention(support, "real_sample")
    random_sample = intervention_fn(torch.tensor(2.0))
    ```
    """

    def _draw(value: torch.Tensor) -> torch.Tensor:
        # Index of the first event dimension. This is never negative-indexed, so
        # event_dim == 0 yields an empty event shape.
        first_event_dim = len(value.shape) - support.event_dim
        proposal = uniform_proposal(
            support,
            event_shape=value.shape[first_event_dim:],
        )
        return pyro.sample(name, proposal)

    return _draw


In [8]:
@contextlib.contextmanager
def SearchForCause(
    actions: Mapping[str, Intervention[T]],
    *,
    bias: float = 0.0,
    prefix: str = "__cause_split_",
):
    """
    Context manager for a stochastic search over minimal but-for causes.

    Every site in ``actions`` is intervened on via ``do(actions=actions)``. Inside that,
    a ``Preemptions`` handler attaches ``undo_split(antecedents=[site])`` to each site
    and forwards ``bias`` and ``prefix``. Each candidate intervention is therefore
    paired with a stochastic binary preemption node.

    Per the original description, the intervention is executed when its preemption node
    is ``0``, which happens with probability ``.5 + bias``. Examples live in
    ``tests/counterfactual/test_handlers_explanation.py``.

    :param actions: A mapping from site names to interventions.
    :param bias: Scalar bias forwarded to ``Preemptions``; expected in [-0.5, 0.5],
        default 0.0. NOTE(review): the original text also called this a bias "towards not
        intervening", which contradicts the ``.5 + bias`` statement. Confirm against
        ``Preemptions``.
    :param prefix: Prefix for the names of the added preemption nodes.
        Defaults to "__cause_split_".
    """
    # TODO support event_dim != 0 propagation in factual_preemption
    undo_each = {site: undo_split(antecedents=[site]) for site in actions}

    with do(actions=actions), Preemptions(
        actions=undo_each, bias=bias, prefix=prefix
    ):
        yield






@contextlib.contextmanager
def ExplainCauses(
    antecedents: Mapping[str, Intervention[T]]
    | Mapping[str, pyro.distributions.constraints.Constraint],
    witnesses: Mapping[str, Intervention[T]] | Iterable[str],
    consequents: Mapping[str, Callable[[T], float | torch.Tensor]]
    | Iterable[str],
    *,
    antecedent_bias: float = 0.0,
    witness_bias: float = 0.0,
    consequent_eps: float = -1e8,
    antecedent_prefix: str = "__antecedent_",
    witness_prefix: str = "__witness_",
    consequent_prefix: str = "__consequent_",
):
    """
    Effect handler for causal explanation.

    Enters, in this order, ``SearchForCause`` over the antecedents, ``Preemptions``
    over the witnesses, ``Factors`` over the consequents, and finally a
    ``pyro.poutine.trace`` handler whose trace is yielded to the ``with`` body.

    :param antecedents: A mapping from antecedent site names to interventions, or from
        site names to support constraints. Only the first value is inspected to decide.
        If it is a constraint, every site gets ``random_intervention(support,
        name=f"{antecedent_prefix}_proposal_{site}")``.
    :param witnesses: A mapping from witness site names to preemption interventions, or
        an iterable of site names. Each name then gets
        ``undo_split(antecedents=<all antecedent names>)``.
    :param consequents: A mapping from consequent site names to factor functions, or an
        iterable of site names. Each name then gets
        ``consequent_differs(antecedents=<all antecedent names>, eps=consequent_eps)``.
    :param antecedent_bias: ``bias`` forwarded to ``SearchForCause``.
    :param witness_bias: ``bias`` forwarded to ``Preemptions`` for the witnesses.
    :param consequent_eps: ``eps`` forwarded to ``consequent_differs`` when consequents
        are given by name only.
    :param antecedent_prefix: Prefix forwarded to ``SearchForCause``; also used to name
        random-proposal sites.
    :param witness_prefix: Prefix forwarded to ``Preemptions`` for the witnesses.
    :param consequent_prefix: Prefix forwarded to ``Factors``.
    :raises ValueError: If there are no antecedents or no consequents, or if the
        consequents overlap the antecedents or the witnesses.
    :return: Yields the trace recorded by the innermost ``pyro.poutine.trace`` handler.
    """
    # `collections` is not among this file's top-level imports; bring the ABCs in
    # locally for the Mapping checks below.
    import collections.abc

    # Check emptiness *before* peeking at the first value. On an empty mapping,
    # ``next(iter(...))`` raises StopIteration, which a generator-based context
    # manager surfaces as an opaque RuntimeError instead of this ValueError.
    if len(antecedents) == 0:
        raise ValueError("must have at least one antecedent")

    # Constraint values mean "search over random interventions on this support".
    if isinstance(
        next(iter(antecedents.values())),
        pyro.distributions.constraints.Constraint,
    ):
        antecedents = {
            a: random_intervention(s, name=f"{antecedent_prefix}_proposal_{a}")
            for a, s in antecedents.items()
        }

    if not isinstance(witnesses, collections.abc.Mapping):
        witnesses = {
            w: undo_split(antecedents=list(antecedents.keys()))
            for w in witnesses
        }

    if not isinstance(consequents, collections.abc.Mapping):
        consequents = {
            c: consequent_differs(
                antecedents=list(antecedents.keys()), eps=consequent_eps
            )
            for c in consequents
        }

    if len(consequents) == 0:
        raise ValueError("must have at least one consequent")

    if set(consequents.keys()) & set(antecedents.keys()):
        raise ValueError(
            "consequents and possible antecedents must be disjoint"
        )

    if set(consequents.keys()) & set(witnesses.keys()):
        raise ValueError("consequents and possible witnesses must be disjoint")

    # `PartOfCause` / `BiasedPreemptions` are not defined or imported anywhere in this
    # file. Use the same-signature `SearchForCause` defined above and the imported
    # `Preemptions` (which already accepts `actions`, `bias` and `prefix`).
    antecedent_handler = SearchForCause(
        actions=antecedents, bias=antecedent_bias, prefix=antecedent_prefix
    )
    witness_handler = Preemptions(
        actions=witnesses, bias=witness_bias, prefix=witness_prefix
    )
    consequent_handler = Factors(factors=consequents, prefix=consequent_prefix)

    with antecedent_handler, witness_handler, consequent_handler:
        with pyro.poutine.trace() as logging_tr:
            yield logging_tr.trace


NameError: name 'Intervention' is not defined