In [26]:
import math

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch

from typing import Callable, Mapping, Optional, TypeVar, Union


from chirho.explainable.handlers.components import (
    consequent_eq_neq,
    random_intervention,
    sufficiency_intervention,
    undo_split,
)

from chirho.observational.handlers.condition import Factors
from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.explainable.handlers.components import undo_split, consequent_eq_neq, sufficiency_intervention
from chirho.explainable.handlers.explanation import SearchForExplanation, SplitSubsets
from chirho.explainable.handlers import ExtractSupports
from chirho.observational.handlers.condition import Factors
from chirho.interventional.handlers import do
from chirho.explainable.handlers.preemptions import Preemptions
from chirho.indexed.ops import IndexSet, gather
from chirho.observational.handlers.condition import condition
from chirho.indexed.ops import indices_of

# Generic type variables (not referenced by any of the cells shown here).
S, T = TypeVar("S"), TypeVar("T")

In [2]:
# Collider ("converge") structure: X -> Z <- Y.

def model_three_converge():
    """Sample two fair coins X and Y and a child Z ~ Bernoulli(min(X, Y)).

    Returns:
        dict mapping the site names "X", "Y", "Z" to their sampled values.
    """
    x_val = pyro.sample("X", dist.Bernoulli(0.5))
    y_val = pyro.sample("Y", dist.Bernoulli(0.5))
    z_prob = torch.min(x_val, y_val)
    z_val = pyro.sample("Z", dist.Bernoulli(z_prob))
    return dict(X=x_val, Y=y_val, Z=z_val)

# Collect per-site supports; later passed to SearchForExplanation as `supports=`.
with ExtractSupports() as supports_converge:
    model_three_converge()

In [3]:
# Collider query: antecedents X=1, Y=1 for the consequent Z=1.
with MultiWorldCounterfactual() as mwc_converge:
    with SearchForExplanation(
        supports=supports_converge.supports,
        antecedents={"Y": torch.tensor(1.0), "X": torch.tensor(1.0)},
        consequents={"Z": torch.tensor(1.0)},
        witnesses={},
        # Alternatives must differ from the antecedents; if they are equal the
        # necessity world replays the sufficiency intervention and the
        # necessity check below becomes vacuous.
        alternatives={"Y": torch.tensor(0.0), "X": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=1):
            with pyro.poutine.trace() as trace_converge:
                model_three_converge()

# Must be *called*: a bare attribute reference without parentheses is a no-op.
trace_converge.trace.compute_log_prob()
nodes = trace_converge.trace.nodes

values = nodes["Z"]["value"]
log_probs = nodes["__cause____consequent_Z"]["fn"].log_factor
assert values.shape == log_probs.shape

nec_worlds = IndexSet(**{name: {1} for name in ["X", "Y"]})
suff_worlds = IndexSet(**{name: {2} for name in ["X", "Y"]})

with mwc_converge:
    # Necessity worlds: the factor is asserted to be exp(.) == 1 exactly when Z != 1.
    nec_value = gather(values, nec_worlds)
    nec_lp = gather(log_probs, nec_worlds)
    assert nec_lp.exp().item() == 1 - nec_value.item()

    # Sufficiency worlds: the factor is asserted to be exp(.) == 1 exactly when Z == 1.
    suff_value = gather(values, suff_worlds)
    suff_lp = gather(log_probs, suff_worlds)
    assert suff_lp.exp().item() == suff_value.item()

# Entries whose two antecedent world indices differ must carry a zero log-factor.
# clone() first: squeeze() returns a view, so fill_diagonal_ would otherwise
# overwrite the log-factor stored in trace_converge.
off_diagonal = log_probs.squeeze().clone().fill_diagonal_(0.0)
assert torch.allclose(off_diagonal, torch.tensor(0.0))

In [4]:
# Fork ("diverge") structure: Y <- X -> Z.

def model_three_diverge():
    """Sample a fair coin X and two children Y and Z, each ~ Bernoulli(X).

    Returns:
        dict mapping the site names "X", "Y", "Z" to their sampled values.
    """
    x_val = pyro.sample("X", dist.Bernoulli(0.5))
    y_val = pyro.sample("Y", dist.Bernoulli(x_val))
    z_val = pyro.sample("Z", dist.Bernoulli(x_val))
    return dict(X=x_val, Y=y_val, Z=z_val)

# Collect per-site supports; later passed to SearchForExplanation as `supports=`.
with ExtractSupports() as supports_diverge:
    model_three_diverge()

In [16]:
# Fork query: antecedents X=1, Y=1 (alternatives 0) for the consequent Z=1.
with MultiWorldCounterfactual() as mwc_diverge:
    with SearchForExplanation(
        supports=supports_diverge.supports,
        antecedents={"Y": torch.tensor(1.0), "X": torch.tensor(1.0)},
        consequents={"Z": torch.tensor(1.0)},
        witnesses={},
        alternatives={"Y": torch.tensor(0.0), "X": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=1):
            with pyro.poutine.trace() as trace_diverge:
                model_three_diverge()

# Must be *called*: a bare attribute reference without parentheses is a no-op.
trace_diverge.trace.compute_log_prob()
nodes = trace_diverge.trace.nodes

values = nodes["Z"]["value"]
log_probs = nodes["__cause____consequent_Z"]["fn"].log_factor

assert log_probs.shape == torch.Size([3, 3, 1, 1, 1, 1])

nec_worlds = IndexSet(**{name: {1} for name in ["X", "Y"]})
suff_worlds = IndexSet(**{name: {2} for name in ["X", "Y"]})

with mwc_diverge:
    # Necessity worlds: the factor is asserted to be exp(.) == 1 exactly when Z != 1.
    nec_value = gather(values, nec_worlds)
    nec_lp = gather(log_probs, nec_worlds)
    assert nec_lp.exp().item() == 1 - nec_value.item()

    # Sufficiency worlds: the factor is asserted to be exp(.) == 1 exactly when Z == 1.
    suff_value = gather(values, suff_worlds)
    suff_lp = gather(log_probs, suff_worlds)
    assert suff_lp.exp().item() == suff_value.item()

# Entries whose two antecedent world indices differ must carry a zero log-factor.
# clone() first: squeeze() returns a view, so fill_diagonal_ would otherwise
# overwrite the log-factor stored in trace_diverge.
off_diagonal = log_probs.squeeze().clone().fill_diagonal_(0.0)
assert torch.allclose(off_diagonal, torch.tensor(0.0))

In [29]:
# Chain structure: X -> Y -> Z.

def model_three_chain():
    """Sample a fair coin X, then Y ~ Bernoulli(X), then Z ~ Bernoulli(Y).

    Returns:
        dict mapping the site names "X", "Y", "Z" to their sampled values.
    """
    x_val = pyro.sample("X", dist.Bernoulli(0.5))
    y_val = pyro.sample("Y", dist.Bernoulli(x_val))
    z_val = pyro.sample("Z", dist.Bernoulli(y_val))
    return dict(X=x_val, Y=y_val, Z=z_val)

# Collect per-site supports; later passed to SearchForExplanation as `supports=`.
with ExtractSupports() as supports_chain:
    model_three_chain()

In [30]:
# Chain query: antecedents X=1, Z=1 (alternatives 0) for the consequent Y=1.
with MultiWorldCounterfactual() as mwc_chain:
    with SearchForExplanation(
        supports=supports_chain.supports,
        antecedents={"X": torch.tensor(1.0), "Z": torch.tensor(1.0)},
        consequents={"Y": torch.tensor(1.0)},
        witnesses={},
        alternatives={"X": torch.tensor(0.0), "Z": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=1):
            with pyro.poutine.trace() as trace_chain:
                model_three_chain()

# Must be *called*: a bare attribute reference without parentheses is a no-op.
trace_chain.trace.compute_log_prob()
nodes = trace_chain.trace.nodes

values = nodes["Y"]["value"]
log_probs = nodes["__cause____consequent_Y"]["fn"].log_factor

print(values.squeeze())
print(log_probs.squeeze())

assert log_probs.shape == torch.Size([3, 3, 1, 1, 1, 1])

nec_worlds = IndexSet(**{name: {1} for name in ["X", "Z"]})
suff_worlds = IndexSet(**{name: {2} for name in ["X", "Z"]})

with mwc_chain:
    # Necessity worlds: the factor is asserted to be exp(.) == 1 exactly when Y != 1.
    nec_value = gather(values, nec_worlds)
    nec_lp = gather(log_probs, nec_worlds)
    assert nec_lp.exp().item() == 1 - nec_value.item()

    # Sufficiency worlds: the factor is asserted to be exp(.) == 1 exactly when Y == 1.
    suff_value = gather(values, suff_worlds)
    suff_lp = gather(log_probs, suff_worlds)
    assert suff_lp.exp().item() == suff_value.item()

# Entries whose two antecedent world indices differ must carry a zero log-factor.
# clone() first: squeeze() returns a view, so fill_diagonal_ would otherwise
# overwrite the log-factor stored in trace_chain.
off_diagonal = log_probs.squeeze().clone().fill_diagonal_(0.0)
assert torch.allclose(off_diagonal, torch.tensor(0.0))

tensor([0., 0., 1.])
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
tensor([[[[[0.]]]]]) tensor([[[[[[0.]]]]]])


In [36]:
# Complete structure: X -> Y -> Z, plus the direct edge X -> Z.

def model_three_complete():
    """Sample a fair coin X, Y ~ Bernoulli(X), and Z ~ Bernoulli(max(X, Y)).

    Returns:
        dict mapping the site names "X", "Y", "Z" to their sampled values.
    """
    x_val = pyro.sample("X", dist.Bernoulli(0.5))
    y_val = pyro.sample("Y", dist.Bernoulli(x_val))
    z_prob = torch.max(x_val, y_val)
    z_val = pyro.sample("Z", dist.Bernoulli(z_prob))
    return dict(X=x_val, Y=y_val, Z=z_val)

# Collect per-site supports; later passed to SearchForExplanation as `supports=`.
with ExtractSupports() as supports_complete:
    model_three_complete()

In [37]:
# Complete-graph query: antecedents X=1, Y=1 (alternatives 0) for the consequent Z=1.
with MultiWorldCounterfactual() as mwc_complete:
    with SearchForExplanation(
        supports=supports_complete.supports,
        antecedents={"X": torch.tensor(1.0), "Y": torch.tensor(1.0)},
        consequents={"Z": torch.tensor(1.0)},
        witnesses={},
        alternatives={"X": torch.tensor(0.0), "Y": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=1):
            with pyro.poutine.trace() as trace_complete:
                model_three_complete()

# Must be *called*: a bare attribute reference without parentheses is a no-op.
trace_complete.trace.compute_log_prob()
nodes = trace_complete.trace.nodes

values = nodes["Z"]["value"]
log_probs = nodes["__cause____consequent_Z"]["fn"].log_factor

print(values.squeeze())
print(log_probs.squeeze())

assert log_probs.shape == torch.Size([3, 3, 1, 1, 1, 1])

nec_worlds = IndexSet(**{name: {1} for name in ["X", "Y"]})
suff_worlds = IndexSet(**{name: {2} for name in ["X", "Y"]})

with mwc_complete:
    # Necessity worlds: the factor is asserted to be exp(.) == 1 exactly when Z != 1.
    nec_value = gather(values, nec_worlds)
    nec_lp = gather(log_probs, nec_worlds)
    assert nec_lp.exp().item() == 1 - nec_value.item()

    # Sufficiency worlds: the factor is asserted to be exp(.) == 1 exactly when Z == 1.
    suff_value = gather(values, suff_worlds)
    suff_lp = gather(log_probs, suff_worlds)
    assert suff_lp.exp().item() == suff_value.item()

# Entries whose two antecedent world indices differ must carry a zero log-factor.
# clone() first: squeeze() returns a view, so fill_diagonal_ would otherwise
# overwrite the log-factor stored in trace_complete.
off_diagonal = log_probs.squeeze().clone().fill_diagonal_(0.0)
assert torch.allclose(off_diagonal, torch.tensor(0.0))

tensor([[1., 0., 1.],
        [1., 0., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])


In [39]:
# Structure X -> Y, with Z an isolated node.

def model_three_isolate():
    """Sample a fair coin X, Y ~ Bernoulli(X), and an independent fair coin Z.

    Returns:
        dict mapping the site names "X", "Y", "Z" to their sampled values.
    """
    x_val = pyro.sample("X", dist.Bernoulli(0.5))
    y_val = pyro.sample("Y", dist.Bernoulli(x_val))
    z_val = pyro.sample("Z", dist.Bernoulli(0.5))
    return dict(X=x_val, Y=y_val, Z=z_val)

# Collect per-site supports; later passed to SearchForExplanation as `supports=`.
with ExtractSupports() as supports_isolate:
    model_three_isolate()

In [46]:
# Isolated-node query: antecedents X=1, Z=1 (alternatives 0) for the consequent Y=1.
with MultiWorldCounterfactual() as mwc_isolate:
    with SearchForExplanation(
        supports=supports_isolate.supports,
        antecedents={"X": torch.tensor(1.0), "Z": torch.tensor(1.0)},
        consequents={"Y": torch.tensor(1.0)},
        witnesses={},
        alternatives={"X": torch.tensor(0.0), "Z": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=1):
            with pyro.poutine.trace() as trace_isolate:
                # Trace the model whose supports were passed above.
                model_three_isolate()

# Must be *called*: a bare attribute reference without parentheses is a no-op.
trace_isolate.trace.compute_log_prob()
nodes = trace_isolate.trace.nodes

values = nodes["Y"]["value"]
log_probs = nodes["__cause____consequent_Y"]["fn"].log_factor

print(values.squeeze())
print(log_probs.squeeze())

assert log_probs.shape == torch.Size([3, 3, 1, 1, 1, 1])

nec_worlds = IndexSet(**{name: {1} for name in ["X", "Z"]})
suff_worlds = IndexSet(**{name: {2} for name in ["X", "Z"]})

with mwc_isolate:
    # Necessity worlds: the factor is asserted to be exp(.) == 1 exactly when Y != 1.
    nec_value = gather(values, nec_worlds)
    nec_lp = gather(log_probs, nec_worlds)
    assert nec_lp.exp().item() == 1 - nec_value.item()

    # Sufficiency worlds: the factor is asserted to be exp(.) == 1 exactly when Y == 1.
    suff_value = gather(values, suff_worlds)
    suff_lp = gather(log_probs, suff_worlds)
    assert suff_lp.exp().item() == suff_value.item()

# Entries whose two antecedent world indices differ must carry a zero log-factor.
# clone() first: squeeze() returns a view, so fill_diagonal_ would otherwise
# overwrite the log-factor stored in trace_isolate.
off_diagonal = log_probs.squeeze().clone().fill_diagonal_(0.0)
assert torch.allclose(off_diagonal, torch.tensor(0.0))

tensor([1., 0., 1.])
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])


In [47]:
import pytest

def model_three_converge():
    """Collider model X -> Z <- Y: two fair coins and Z ~ Bernoulli(min(X, Y)).

    NOTE(review): identical to the earlier definition of the same name; this
    redefinition shadows it.
    """
    sampled_x = pyro.sample("X", dist.Bernoulli(0.5))
    sampled_y = pyro.sample("Y", dist.Bernoulli(0.5))
    sampled_z = pyro.sample("Z", dist.Bernoulli(torch.min(sampled_x, sampled_y)))
    return dict(X=sampled_x, Y=sampled_y, Z=sampled_z)

def model_three_diverge():
    """Fork model Y <- X -> Z: a fair coin X with children Y, Z ~ Bernoulli(X).

    NOTE(review): identical to the earlier definition of the same name; this
    redefinition shadows it.
    """
    sampled_x = pyro.sample("X", dist.Bernoulli(0.5))
    sampled_y = pyro.sample("Y", dist.Bernoulli(sampled_x))
    sampled_z = pyro.sample("Z", dist.Bernoulli(sampled_x))
    return dict(X=sampled_x, Y=sampled_y, Z=sampled_z)


In [48]:
@pytest.mark.parametrize("model", [model_three_converge, model_three_diverge])
def test_three_variables(model):
    """Check the consequent log-factor on Y for antecedents X, Z in a 3-node model.

    The search uses antecedents X=1, Z=1, alternatives X=0, Z=0 and the
    consequent Y=1. The test asserts that:
      * the consequent log-factor has shape (3, 3, 1, 1, 1, 1);
      * in the necessity worlds (index 1), exp(factor) == 1 - Y;
      * in the sufficiency worlds (index 2), exp(factor) == Y;
      * entries whose two antecedent world indices differ have log-factor 0.

    Args:
        model: zero-argument Pyro model with sample sites "X", "Y" and "Z".
    """
    with ExtractSupports() as supports:
        model()

    with MultiWorldCounterfactual() as mwc:
        with SearchForExplanation(
            supports=supports.supports,
            antecedents={"X": torch.tensor(1.0), "Z": torch.tensor(1.0)},
            consequents={"Y": torch.tensor(1.0)},
            witnesses={},
            alternatives={"X": torch.tensor(0.0), "Z": torch.tensor(0.0)},
            antecedent_bias=-0.5,
            consequent_scale=0,
        ):
            with pyro.plate("sample", size=1):
                with pyro.poutine.trace() as trace:
                    model()

    # Must be *called*: a bare attribute reference without parentheses is a no-op.
    trace.trace.compute_log_prob()
    nodes = trace.trace.nodes

    values = nodes["Y"]["value"]
    log_probs = nodes["__cause____consequent_Y"]["fn"].log_factor

    print(values.squeeze())
    print(log_probs.squeeze())

    assert log_probs.shape == torch.Size([3, 3, 1, 1, 1, 1])

    nec_worlds = IndexSet(**{name: {1} for name in ["X", "Z"]})
    suff_worlds = IndexSet(**{name: {2} for name in ["X", "Z"]})

    with mwc:
        nec_value = gather(values, nec_worlds)
        nec_lp = gather(log_probs, nec_worlds)
        assert nec_lp.exp().item() == 1 - nec_value.item()

        suff_value = gather(values, suff_worlds)
        suff_lp = gather(log_probs, suff_worlds)
        assert suff_lp.exp().item() == suff_value.item()

    # clone() first: squeeze() returns a view, so fill_diagonal_ would otherwise
    # overwrite the log-factor stored in the trace.
    off_diagonal = log_probs.squeeze().clone().fill_diagonal_(0.0)
    assert torch.allclose(off_diagonal, torch.tensor(0.0))

    


In [49]:
# Run the parametrized test directly (outside pytest) for the fork model.
test_three_variables(model=model_three_diverge)

tensor([0., 0., 1.])
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
