In [1]:
import math

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch

from typing import Callable, Mapping, Optional, TypeVar, Union


from chirho.explainable.handlers.components import (
    consequent_eq_neq,
    random_intervention,
    sufficiency_intervention,
    undo_split,
)

from chirho.observational.handlers.condition import Factors
from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.explainable.handlers.components import undo_split, consequent_eq_neq, sufficiency_intervention
from chirho.explainable.handlers.explanation import SearchForExplanation, SplitSubsets
from chirho.explainable.handlers import ExtractSupports
from chirho.observational.handlers.condition import Factors
from chirho.interventional.handlers import do
from chirho.explainable.handlers.preemptions import Preemptions
from chirho.indexed.ops import IndexSet, gather
from chirho.observational.handlers.condition import condition
from chirho.indexed.ops import indices_of

# Generic type variables.
# NOTE(review): neither S nor T is referenced in the cells visible here —
# confirm they are used by later cells before keeping them.
S = TypeVar("S")
T = TypeVar("T")

In [2]:
def model_three_independent():
    """Draw three independent fair coin flips ``X``, ``Y`` and ``Z``.

    Each site is sampled from ``Bernoulli(0.5)`` in the order X, Y, Z, and no
    site depends on another. Returns a dict mapping each site name to its
    sampled value.
    """
    return {
        site_name: pyro.sample(site_name, dist.Bernoulli(0.5))
        for site_name in ("X", "Y", "Z")
    }


# Run the model once under ExtractSupports; its `.supports` attribute is what
# the SearchForExplanation queries in the following cells consume.
with ExtractSupports() as supports_independent:
    model_three_independent()

In [34]:
# Query: is the antecedent set {X=1, Y=1} (alternatives X=0, Y=0) a cause of
# the consequent Z=1? All three sites are independent in
# `model_three_independent`, so the expected factors checked below depend only
# on the sampled value of Z.
with MultiWorldCounterfactual() as mwc_independent:
    with SearchForExplanation(
        supports=supports_independent.supports,
        antecedents={"X": torch.tensor(1.0), "Y": torch.tensor(1.0)},
        consequents={"Z": torch.tensor(1.0)},
        witnesses={},
        alternatives={"X": torch.tensor(0.0), "Y": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=1):
            with pyro.poutine.trace() as trace_independent:
                model_three_independent()

# Must be *called*: merely naming the bound method is a no-op and leaves the
# per-site log-probabilities uncomputed.
trace_independent.trace.compute_log_prob()
nodes = trace_independent.trace.nodes

# Log-factor that SearchForExplanation attached to the consequent Z.
log_probs = nodes["__cause____consequent_Z"]["fn"].log_factor

assert log_probs.shape == torch.Size([3, 3, 1, 1, 1, 1]), (
    f"unexpected log-factor shape {tuple(log_probs.shape)}"
)

# Re-enter the counterfactual handler before inspecting the indexed worlds
# with `indices_of` / `gather`.
with mwc_independent:
    print(indices_of(log_probs))
    # World index 1 for both antecedents (the author's necessity world).
    nec_lp = gather(log_probs, IndexSet(**{"X": {1}, "Y": {1}}))
    z_observed = nodes["Z"]["value"].item()
    print(z_observed)
    z_equals_consequent = z_observed == 1

    # Tolerant comparison: exact float equality on exp(log_prob) is brittle
    # if consequent_scale is ever made nonzero.
    expected_nec = 0.0 if z_equals_consequent else 1.0
    assert math.isclose(nec_lp.exp().item(), expected_nec, abs_tol=1e-6)

    # World index 2 for both antecedents (the author's sufficiency world).
    suff_lp = gather(log_probs, IndexSet(**{"X": {2}, "Y": {2}}))
    expected_suff = 1.0 if z_equals_consequent else 0.0
    assert math.isclose(suff_lp.exp().item(), expected_suff, abs_tol=1e-6)


IndexSet({'X': {0, 1, 2}, 'Y': {0, 1, 2}})
0.0


In [35]:
# Query with roles swapped: is {X=1, Z=1} (alternatives X=0, Z=0) a cause of
# Y=1? Note that in `model_three_independent` the consequent Y is sampled
# *before* the antecedent Z, unlike the previous cell.
with MultiWorldCounterfactual() as mwc_independent:
    with SearchForExplanation(
        supports=supports_independent.supports,
        antecedents={"X": torch.tensor(1.0), "Z": torch.tensor(1.0)},
        consequents={"Y": torch.tensor(1.0)},
        witnesses={},
        alternatives={"X": torch.tensor(0.0), "Z": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=1):
            with pyro.poutine.trace() as trace_independent:
                model_three_independent()

# Must be *called*: merely naming the bound method is a no-op.
trace_independent.trace.compute_log_prob()
nodes = trace_independent.trace.nodes

# Log-factor attached to the consequent Y; the check is whether it carries
# world dimensions for both antecedents even though Z is sampled after Y.
log_factor_y = nodes["__cause____consequent_Y"]["fn"].log_factor
assert log_factor_y.shape == torch.Size([3, 3, 1, 1, 1, 1]), (
    f"unexpected log-factor shape {tuple(log_factor_y.shape)}"
)

