In [1]:
import functools
import contextlib
import collections
from typing import Callable, Iterable, TypeVar, Mapping

import pyro
import torch  # noqa: F401

from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.indexed.ops import IndexSet, cond, gather, indices_of, scatter

S = TypeVar("S")
T = TypeVar("T")

import pyro
import chirho
import pyro.distributions as dist
import pyro.infer
import pytest
import torch
import pandas as pd

from chirho.counterfactual.handlers.counterfactual import (MultiWorldCounterfactual,
        Preemptions)
from chirho.counterfactual.handlers.explanation import (
    SearchForCause,
    consequent_differs,
    random_intervention,
    undo_split,
    uniform_proposal,
)
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.observational.handlers.condition import Factors, condition
from chirho.interventional.ops import Intervention, intervene
from chirho.interventional.handlers import do

In [2]:
@contextlib.contextmanager
def ExplainCauses(
    antecedents: Mapping[str, Intervention[T]]
    | Mapping[str, pyro.distributions.constraints.Constraint],
    witnesses: Mapping[str, Intervention[T]] | Iterable[str],
    consequents: Mapping[str, Callable[[T], float | torch.Tensor]]
    | Iterable[str],
    *,
    antecedent_bias: float = 0.0,
    witness_bias: float = 0.0,
    consequent_eps: float = -1e8,
    antecedent_prefix: str = "__antecedent_",
    witness_prefix: str = "__witness_",
    consequent_prefix: str = "__consequent_",
):
    """
    Effect handler for causal explanation.

    :param antecedents: A mapping from antecedent names to interventions.
    :param witnesses: A mapping from witness names to interventions.
    :param consequents: A mapping from consequent names to factor functions.
    """
    if isinstance(
        next(iter(antecedents.values())),
        pyro.distributions.constraints.Constraint,
    ):
        antecedents = {
            a: random_intervention(s, name=f"{antecedent_prefix}_proposal_{a}")
            for a, s in antecedents.items()
        }

    if not isinstance(witnesses, collections.abc.Mapping):
        witnesses = {
            w: undo_split(antecedents=list(antecedents.keys()))
            for w in witnesses
        }

    if not isinstance(consequents, collections.abc.Mapping):
        consequents = {
            c: consequent_differs(
                antecedents=list(antecedents.keys()), eps=consequent_eps
            )
            for c in consequents
        }

    if len(consequents) == 0:
        raise ValueError("must have at least one consequent")

    if len(antecedents) == 0:
        raise ValueError("must have at least one antecedent")

    if set(consequents.keys()) & set(antecedents.keys()):
        raise ValueError(
            "consequents and possible antecedents must be disjoint"
        )

    if set(consequents.keys()) & set(witnesses.keys()):
        raise ValueError("consequents and possible witnesses must be disjoint")

    antecedent_handler = SearchForCause(
        actions=antecedents, bias=antecedent_bias, prefix=antecedent_prefix
    )
    witness_handler = Preemptions(
        actions=witnesses, bias=witness_bias, prefix=witness_prefix
    )
    consequent_handler = Factors(factors=consequents, prefix=consequent_prefix)

    with antecedent_handler, witness_handler, consequent_handler:
        with pyro.poutine.trace() as logging_tr:
            yield

### Stone-throwing

In [6]:
@pyro.infer.config_enumerate
def stones_model():        
    prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1))
    prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1))
    prob_sally_hits = pyro.sample("prob_sally_hits", dist.Beta(1, 1))
    prob_bill_hits = pyro.sample("prob_bill_hits", dist.Beta(1, 1))
    prob_bottle_shatters_if_sally = pyro.sample("prob_bottle_shatters_if_sally", dist.Beta(1, 1))
    prob_bottle_shatters_if_bill = pyro.sample("prob_bottle_shatters_if_bill", dist.Beta(1, 1))

    sally_throws = pyro.sample("sally_throws", dist.Bernoulli(prob_sally_throws))
    bill_throws = pyro.sample("bill_throws", dist.Bernoulli(prob_bill_throws))


    new_shp = torch.where(sally_throws == 1,prob_sally_hits, 0.0)

    sally_hits = pyro.sample("sally_hits",dist.Bernoulli(new_shp))

    new_bhp = torch.where(
        bill_throws.bool() & (~sally_hits.bool()),
        prob_bill_hits,
        torch.tensor(0.0),
    )

    bill_hits = pyro.sample("bill_hits", dist.Bernoulli(new_bhp))
    

    new_bsp = torch.where(bill_hits.bool(), prob_bottle_shatters_if_bill,
            torch.where(sally_hits.bool(),prob_bottle_shatters_if_sally,torch.tensor(0.0),),)

    bottle_shatters = pyro.sample("bottle_shatters", dist.Bernoulli(new_bsp))

    return {"sally_throws": sally_throws, "bill_throws": bill_throws,  "sally_hits": sally_hits,
            "bill_hits": bill_hits,  "bottle_shatters": bottle_shatters,}

stones_model.nodes = ["sally_throws","bill_throws", "sally_hits", "bill_hits","bottle_shatters",]

In [36]:
observations = {"prob_sally_throws": 1.0, 
                "prob_bill_throws": 1.0,
                "prob_sally_hits": 1.0,
                "prob_bill_hits": 1.0,
                "prob_bottle_shatters_if_sally": 1.0,
                "prob_bottle_shatters_if_bill": 1.0,
                "sally_throws": 1.0, "bill_throws": 1.0}

observations_tensorized = {k: torch.as_tensor(v) for k, v in observations.items()}

antecedents = {"sally_throws": 0.0}
antencedent_bias = 0.1
witnesses = ["bill_throws", "bill_hits"]
consequents = ["bottle_shatters"]

#TODO? fails silently when consequents is a string

In [34]:
with MultiWorldCounterfactual() as mwc:
    with ExplainCauses(antecedents = antecedents, antecedent_bias= antencedent_bias,
                        witnesses = witnesses,
                        consequents = consequents):
        with condition(data = observations_tensorized):
            with pyro.plate("sample", 15):
                with pyro.poutine.trace() as tr:
                    stones_model()

In [90]:
def gather_observed(value):
    _indices = [
            i for i in list(antecedents.keys()) + witnesses if i in indices_of(nodes[candidate]["value"], event_dim=0)
        ]
    _int_can = gather(
    nodes[candidate]["value"], IndexSet(**{i: {0} for i in _indices}), event_dim=0,)
    return _int_can

def gather_intervened(value):
    _indices = [
            i for i in list(antecedents.keys()) + witnesses if i in indices_of(nodes[candidate]["value"], event_dim=0)
        ]
    _int_can = gather(
    nodes[candidate]["value"], IndexSet(**{i: {1} for i in _indices}), event_dim=0,)
    return _int_can


def get_table(trace):

    values_table = {}
    trace = tr
    nodes = trace.trace.nodes

    with mwc:

        for antecedent_str in antecedents.keys():
                
            obs_ant = gather_observed(nodes[antecedent_str]["value"])
            int_ant = gather_observed(nodes[antecedent_str]["value"])

            values_table[f"{antecedent_str}_obs"] = obs_ant.squeeze().tolist()
            values_table[f"{antecedent_str}_int"] = int_ant.squeeze().tolist()
            
            apr_ant = nodes[f"__antecedent_{antecedent_str}"]["value"]
            values_table[f"apr_{antecedent_str}"] = apr_ant.squeeze().tolist()
            
            values_table[f"apr_{antecedent_str}_lp"] = nodes[f"__antecedent_{antecedent_str}"]["fn"].log_prob(apr_ant)

        for candidate in witnesses:
            obs_candidate = gather_observed(nodes[candidate]["value"])
            int_candidate = gather_intervened(nodes[candidate]["value"])
            values_table[f"{candidate}_obs"] = obs_candidate.squeeze().tolist()
            values_table[f"{candidate}_int"] = int_candidate.squeeze().tolist()

        for consequent in consequents:
            obs_consequent = gather_observed(nodes[consequent]["value"])
            int_consequent = gather_intervened(nodes[consequent]["value"])
            con_lp = nodes[f"__consequent_{consequent}"]['fn'].log_prob(torch.tensor(1)) #TODO: this feels like a hack
            _indices_lp = [
            i for i in list(antecedents.keys()) + witnesses if i in indices_of(con_lp)]
            int_con_lp = gather(con_lp, IndexSet(**{i: {1} for i in _indices_lp}), event_dim=0,)      


            values_table[f"{consequent}_obs"] = obs_consequent.squeeze().tolist()   
            values_table[f"{consequent}_int"] = int_consequent.squeeze().tolist()
            values_table[f"{consequent}_lp"] = int_con_lp.squeeze().tolist()   

    values_df = pd.DataFrame(values_table)
    values_df.drop_duplicates(inplace=True)
    return values_df

In [89]:
stones_table = get_table(tr)
print(stones_table.shape)
display(stones_table)

(15, 11)


Unnamed: 0,sally_throws_obs,sally_throws_int,apr_sally_throws,apr_sally_throws_lp,bill_throws_obs,bill_throws_int,bill_hits_obs,bill_hits_int,bottle_shatters_obs,bottle_shatters_int,bottle_shatters_lp
0,0.0,0.0,1,-0.510826,0.0,0.0,0.0,0.0,0.0,0.0,-100000000.0
1,0.0,0.0,1,-0.510826,0.0,0.0,0.0,0.0,0.0,0.0,-100000000.0
2,0.0,0.0,1,-0.510826,0.0,0.0,0.0,0.0,0.0,0.0,-100000000.0
3,0.0,0.0,1,-0.510826,0.0,0.0,0.0,0.0,0.0,0.0,-100000000.0
4,0.0,0.0,1,-0.510826,0.0,0.0,0.0,0.0,0.0,0.0,-100000000.0
5,0.0,0.0,1,-0.510826,0.0,0.0,0.0,0.0,0.0,0.0,-100000000.0
6,0.0,0.0,1,-0.510826,0.0,0.0,0.0,0.0,0.0,0.0,-100000000.0
7,0.0,0.0,1,-0.510826,0.0,0.0,0.0,0.0,0.0,0.0,-100000000.0
8,0.0,0.0,0,-0.916291,0.0,1.0,0.0,1.0,0.0,1.0,-100000000.0
9,0.0,0.0,0,-0.916291,0.0,0.0,0.0,0.0,0.0,0.0,0.0
