In [15]:
import math

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch

from typing import Callable, Mapping, Optional, TypeVar, Union


from chirho.explainable.handlers.components import (
    consequent_eq_neq,
    random_intervention,
    sufficiency_intervention,
    undo_split,
)

from chirho.observational.handlers.condition import Factors
from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.explainable.handlers.components import undo_split, consequent_eq_neq, sufficiency_intervention
from chirho.explainable.handlers.explanation import SearchForExplanation, SplitSubsets
from chirho.explainable.handlers import ExtractSupports
from chirho.observational.handlers.condition import Factors
from chirho.interventional.handlers import do
from chirho.explainable.handlers.preemptions import Preemptions
from chirho.indexed.ops import IndexSet, gather
from chirho.observational.handlers.condition import condition
from chirho.indexed.ops import indices_of

# Generic type variables available for type hints in this notebook.
S, T = TypeVar("S"), TypeVar("T")

In [16]:
def model_three_independent():
    """Sample three mutually independent fair Bernoulli variables.

    Draws the sample sites "X", "Y" and "Z" (in that order), each from its own
    ``Bernoulli(0.5)`` distribution; no site depends on another.

    Returns:
        dict mapping each site name ("X", "Y", "Z") to its sampled value.
    """
    # The comprehension iterates in tuple order, so sites are sampled X, then Y, then Z.
    return {site: pyro.sample(site, dist.Bernoulli(0.5)) for site in ("X", "Y", "Z")}

# Run the model once under ExtractSupports. Its `.supports` attribute is passed to
# SearchForExplanation in the next cell (presumably a mapping from sample-site
# name to that site's support constraint — TODO confirm).
with ExtractSupports() as supports_independent:
    model_three_independent()

In [17]:
# NOTE(review): these two lists are not used by the checks in this cell (the
# `tests` list below drives the loop); kept in case another cell refers to them.
antecedent_sets = [{"X": torch.tensor(1.0), "Y": torch.tensor(1.0)}, {"X": torch.tensor(1.0), "Z": torch.tensor(1.0)}]
consequent_sets = [{"Z": torch.tensor(1.0)}, {"Y": torch.tensor(1.0)}]

# World positions gathered along each antecedent's counterfactual index dimension.
# The shape assertion below shows three positions; position 0 is presumably the
# factual world — TODO confirm.
NECESSITY_WORLD = 1
SUFFICIENCY_WORLD = 2

# Each case is (antecedent1, antecedent2, consequent) over the independent X/Y/Z model.
tests = [("X", "Y", "Z"), ("X", "Z", "Y")]
for antecedent1, antecedent2, consequent in tests:
    print(antecedent1, antecedent2, consequent)
    with MultiWorldCounterfactual() as mwc_independent:
        with SearchForExplanation(
            supports=supports_independent.supports,
            antecedents={antecedent1: torch.tensor(0.0), antecedent2: torch.tensor(0.0)},
            consequents={consequent: torch.tensor(1.0)},
            witnesses={},
            alternatives={antecedent1: torch.tensor(0.0), antecedent2: torch.tensor(0.0)},
            antecedent_bias=-0.5,
            consequent_scale=0,
        ):
            with pyro.plate("sample", size=1):
                with pyro.poutine.trace() as trace_independent:
                    model_three_independent()

    # Must be called: a bare `trace.compute_log_prob` attribute access does nothing.
    trace_independent.trace.compute_log_prob()
    nodes = trace_independent.trace.nodes

    # Log-factor of the consequent's factor site (presumably registered by
    # SearchForExplanation under the `__cause____consequent_` prefix).
    log_probs = nodes[f"__cause____consequent_{consequent}"]["fn"].log_factor

    # Presumably one size-3 world dimension per antecedent; all other dims are singleton.
    assert log_probs.shape == torch.Size([3, 3, 1, 1, 1, 1])

    antecedent_names = (antecedent1, antecedent2)
    nec_worlds = IndexSet(**{name: {NECESSITY_WORLD} for name in antecedent_names})
    suff_worlds = IndexSet(**{name: {SUFFICIENCY_WORLD} for name in antecedent_names})

    # Both expectations hinge on whether the sampled consequent equals the target 1:
    # if it does, the necessity factor's probability is 0 and the sufficiency
    # factor's is 1; otherwise the reverse.
    consequent_matches_target = nodes[consequent]["value"].item() == 1

    # gather is run inside the same MultiWorldCounterfactual handler (presumably so
    # it can resolve the counterfactual index dimensions).
    with mwc_independent:
        nec_lp = gather(log_probs, nec_worlds)
        assert nec_lp.exp().item() == (0.0 if consequent_matches_target else 1.0)

        suff_lp = gather(log_probs, suff_worlds)
        assert suff_lp.exp().item() == (1.0 if consequent_matches_target else 0.0)

    # Every off-diagonal entry of the squeezed 3x3 grid must have log-factor 0.
    # clone() first: squeeze() returns a view, so an in-place fill_diagonal_ would
    # otherwise overwrite the log_factor tensor stored in the trace node.
    assert torch.allclose(log_probs.squeeze().clone().fill_diagonal_(0.0), torch.tensor(0.0))

X Y Z
X Z Y
