In [1]:
%env CUDA_VISIBLE_DEVICES=-1

import os

import contextlib
from typing import Callable, Mapping, TypeVar, Union

import pyro.distributions.constraints as constraints
import torch

from chirho.explainable.handlers.components import (
    consequent_neq,
    random_intervention,
    sufficiency_intervention,
    undo_split,
)
from chirho.explainable.handlers.preemptions import Preemptions
from chirho.interventional.handlers import do
from chirho.interventional.ops import Intervention
from chirho.observational.handlers.condition import Factors

# Generic type variables for annotations. `T` parameterizes the
# `Intervention[T]` / `Callable[[T], ...]` hints in `SearchForNS` below;
# `S` is not referenced in the cells visible here.
S = TypeVar("S")
T = TypeVar("T")

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch

import matplotlib.pyplot as plt

import pandas as pd
from typing import Callable, Mapping, TypeVar, Union


from chirho.observational.handlers import condition
from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.explainable.handlers import SplitSubsets, SearchForExplanation #, SearchForNS


from chirho.indexed.ops import (IndexSet, gather, indices_of) 
from chirho.interventional.handlers import do

env: CUDA_VISIBLE_DEVICES=-1


In [2]:
def ff_disjunctive():
    """
    Disjunctive forest-fire model.

    Samples two Bernoulli sites, ``match_dropped`` (p=0.7) and ``lightning``
    (p=0.4), and records ``forest_fire`` as the elementwise maximum of the two,
    i.e. the fire occurs when either cause is present.

    :return: dict with the values of ``match_dropped``, ``lightning`` and
        ``forest_fire``.
    """
    # The two causes deliberately have unequal prior probabilities.
    match = pyro.sample("match_dropped", dist.Bernoulli(0.7))
    strike = pyro.sample("lightning", dist.Bernoulli(0.4))

    # For 0/1 values the maximum acts as a logical OR.
    fire = pyro.deterministic(
        "forest_fire",
        torch.max(match, strike),
        event_dim=0,
    )

    return dict(match_dropped=match, lightning=strike, forest_fire=fire)

# Evidence passed to `condition(data=...)` in the cells below:
# the match was dropped, there was no lightning, and the forest burned.
observations = dict(
    match_dropped=torch.tensor(1.0),
    lightning=torch.tensor(0.0),
    forest_fire=torch.tensor(1.0),
)

# Candidate cause, mapped to the alternative value (0.0) used as its intervention.
antecedents = dict(match_dropped=torch.tensor(0.0))

# ignore witnesses for now
witnesses = dict()

# Outcome of interest, described by its support.
consequents = dict(forest_fire=constraints.boolean)


In [3]:

# Trace the model with `match_dropped` intervened on with two alternative
# values (0.0 and 8.0) while conditioning on the observations.
with MultiWorldCounterfactual() as mwc, do(
    actions={"match_dropped": (torch.tensor(0.0), torch.tensor(8.0))}
):
    with condition(data=observations), pyro.poutine.trace() as int_trace:
        ff_disjunctive()

# Value recorded at the `match_dropped` site of the trace.
match_dropped_value = int_trace.trace.nodes["match_dropped"]["value"]
print(match_dropped_value)

# Re-enter the counterfactual handler to query the named indices of that value.
with mwc:
    print(indices_of(match_dropped_value))

tensor([[[[[1.]]]],



        [[[[0.]]]],



        [[[[8.]]]]])
IndexSet({'match_dropped': {0, 1, 2}})


In [4]:
# Run the model 10 times (plate "sample") under SearchForExplanation,
# conditioning on the observations and recording a trace.
with MultiWorldCounterfactual() as ffd_mwc:  # needed to keep track of multiple scenarios
    with SearchForExplanation(
        antecedents=antecedents,
        witnesses=witnesses,
        consequents=consequents,
        consequent_scale=1e-8,
    ), condition(data=observations):
        with pyro.plate("sample", 10), pyro.poutine.trace() as ffd_tr:  # run a few times
            ff_disjunctive()

# Populate the per-site `log_prob` entries before reading them.
ffd_tr.trace.compute_log_prob()
ffd_nd = ffd_tr.trace.nodes

# Select index 1 along `match_dropped` from the consequent factor's log-probability.
with ffd_mwc:
    consequent_log_prob = ffd_nd["__consequent_forest_fire"]["log_prob"]
    original_intervened = gather(consequent_log_prob, IndexSet(match_dropped={1}))
    print(original_intervened)

tensor([[[[[-inf, -inf, 0., 0., 0., 0., 0., 0., 0., 0.]]]]])


In [5]:
@contextlib.contextmanager
def SearchForNS(
    antecedents: Union[
        Mapping[str, Intervention[T]],
        Mapping[str, constraints.Constraint],
    ],
    witnesses: Union[
        Mapping[str, Intervention[T]], Mapping[str, constraints.Constraint]
    ],
    consequents: Union[
        Mapping[str, Callable[[T], Union[float, torch.Tensor]]],
        Mapping[str, constraints.Constraint],
    ],
    *,
    antecedent_bias: float = 0.0,
    witness_bias: float = 0.0,
    consequent_scale: float = 1e-2,
    antecedent_prefix: str = "__antecedent_",
    witness_prefix: str = "__witness_",
    consequent_prefix: str = "__consequent_",
):
    """
    Experimental variant of ``SearchForExplanation`` (the name presumably stands
    for necessity/sufficiency search -- TODO confirm and expand this docstring).

    The context manager builds three handlers and enters them together:

      1. A ``SplitSubsets`` handler over the antecedents. Every antecedent is
         given a *pair* of interventions: the caller-supplied one (or, when a
         constraint is supplied, a ``random_intervention`` over that support),
         followed by a ``sufficiency_intervention`` built from the support and
         the antecedent names.

      2. A ``Preemptions`` handler over the witnesses. Constraint-valued
         witnesses are first converted to ``undo_split`` actions keyed on the
         antecedent names.

      3. A ``Factors`` handler over the consequents. Constraint-valued
         consequents are first converted to ``consequent_neq`` factors with
         scale ``consequent_scale``; any other mapping is used as given.

    Whether a mapping holds constraints is decided by inspecting its first value
    only, so each mapping should be homogeneous.

    :param antecedents: Mapping from antecedent names to interventions or to
        constraints. When interventions are given, every antecedent is assumed
        to have boolean support.
    :param witnesses: Mapping from witness names to interventions or to constraints.
    :param consequents: Mapping from consequent names to factor functions or to
        constraints.
    :param antecedent_bias: ``bias`` passed to ``SplitSubsets``.
    :param witness_bias: ``bias`` passed to ``Preemptions``.
    :param consequent_scale: ``scale`` passed to ``consequent_neq``.
    :param antecedent_prefix: Prefix for antecedent-related sites.
    :param witness_prefix: Prefix for witness-related sites.
    :param consequent_prefix: Base of the prefix handed to ``Factors``.
    :raises ValueError: If there are no consequents, no antecedents, or if the
        consequent names overlap with antecedent or witness names.
    """
    if antecedents and isinstance(
        next(iter(antecedents.values())),
        constraints.Constraint,
    ):
        # Constraints were given: remember them as supports and propose random
        # intervention values from each support.
        antecedents_supports = {a: s for a, s in antecedents.items()}

        antecedents = {
            a: (
                random_intervention(s, name=f"{antecedent_prefix}_proposal_{a}"),
                sufficiency_intervention(s, antecedents.keys()),
            )
            for a, s in antecedents_supports.items()
        }
    else:
        # Explicit interventions were given; supports are not known, so boolean
        # support is assumed for every antecedent.
        # TODO generalize to non-scalar antecedents
        antecedents_supports = {a: constraints.boolean for a in antecedents.keys()}

        antecedents = {
            a: (
                antecedents[a],
                sufficiency_intervention(s, antecedents.keys()),
            )
            for a, s in antecedents_supports.items()
        }

    if witnesses and isinstance(
        next(iter(witnesses.values())),
        constraints.Constraint,
    ):
        witnesses = {
            w: undo_split(s, antecedents=list(antecedents.keys()))
            for w, s in witnesses.items()
        }

    if consequents and isinstance(
        next(iter(consequents.values())),
        constraints.Constraint,
    ):
        consequents_neq = {
            c: consequent_neq(
                support=s,
                antecedents=list(antecedents.keys()),
                scale=consequent_scale,  # TODO allow for different scales for neq and eq
            )
            for c, s in consequents.items()
        }
    else:
        # Factor functions (or an empty mapping) are used unchanged. Without
        # this branch `consequents_neq` was unbound, so the validation below
        # failed with a NameError instead of the intended ValueError.
        consequents_neq = dict(consequents)

    if len(consequents_neq) == 0:
        raise ValueError("must have at least one consequent")

    if len(antecedents) == 0:
        raise ValueError("must have at least one antecedent")

    if set(consequents_neq.keys()) & set(antecedents.keys()):
        raise ValueError("consequents and possible antecedents must be disjoint")

    if set(consequents_neq.keys()) & set(witnesses.keys()):
        raise ValueError("consequents and possible witnesses must be disjoint")

    antecedent_handler = SplitSubsets(
        supports=antecedents_supports,
        actions=antecedents,
        bias=antecedent_bias,
        prefix=antecedent_prefix,
    )

    witness_handler: Preemptions = Preemptions(
        actions=witnesses, bias=witness_bias, prefix=witness_prefix
    )

    # NOTE(review): with the default `consequent_prefix` this prefix evaluates to
    # "__consequent__neq" (double underscore, no trailing separator); confirm
    # the resulting factor site names are the intended ones.
    consequent_neq_handler = Factors(
        factors=consequents_neq, prefix=f"{consequent_prefix}_neq"
    )

    with antecedent_handler, witness_handler, consequent_neq_handler:
        yield


In [6]:
# Run the model once under SearchForNS, conditioning on the observations and
# recording a trace.
with MultiWorldCounterfactual() as ffd_ns_mwc:  # needed to keep track of multiple scenarios
    with SearchForNS(
        antecedents=antecedents,
        antecedent_bias=-0.5,
        witnesses=witnesses,
        consequents=consequents,
        consequent_scale=1e-8,
    ):
        with condition(data=observations):
            with pyro.poutine.trace() as ffd_ns_tr:
                ff_disjunctive()

ffd_ns_tr.trace.compute_log_prob()
ffd_ns_nd = ffd_ns_tr.trace.nodes

# Inspect the trace under the handler that recorded this run (`ffd_ns_mwc`),
# the same way `ffd_mwc` is re-entered for the SearchForExplanation run.
# Previously this re-entered `mwc`, which belongs to the earlier `do` example.
with ffd_ns_mwc:
    print(
        "sufficiency_interv",
        sufficiency_intervention(constraints.boolean, ["match_dropped"])(
            ffd_ns_nd["match_dropped"]["value"]
        ),
    )
    print(indices_of(ffd_ns_nd["forest_fire"]["value"]))


print(ffd_ns_nd["match_dropped"]["value"])



NameError: name 'consequent_handler' is not defined