In [4]:
import math

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import pytest
import torch

from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.explainable.handlers import ExtractSupports
from chirho.explainable.handlers.components import undo_split
from chirho.explainable.handlers.explanation import SearchForExplanation, SplitSubsets
from chirho.explainable.handlers.preemptions import Preemptions
from chirho.indexed.ops import IndexSet, gather
from chirho.observational.handlers.condition import condition

In [13]:
def model_independent():
    X = pyro.sample("X", dist.Bernoulli(0.5))
    Y = pyro.sample("Y", dist.Bernoulli(0.5))
    return {"X": X, "Y": Y}

def model_connected():
    X = pyro.sample("X", dist.Bernoulli(0.5))
    Y = pyro.sample("Y", dist.Bernoulli(X))
    return {"X": X, "Y": Y}

with ExtractSupports() as supports_independent:
    model_independent()

with ExtractSupports() as supports_connected:
    model_connected()

with MultiWorldCounterfactual() as mwc_ind:
    with SearchForExplanation(
        supports=supports_independent.supports,
        antecedents={"X": torch.tensor(1.0)},
        consequents={"Y": torch.tensor(1.0)},
        witnesses={},
        alternatives={"X": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=3):
            with pyro.poutine.trace() as trace_independent:
                model_independent()

with MultiWorldCounterfactual() as mwc_con:
    with SearchForExplanation(
        supports=supports_connected.supports,
        antecedents={"X": torch.tensor(1.0)},
        consequents={"Y": torch.tensor(1.0)},
        witnesses={},
        alternatives={"X": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=3):
            with pyro.poutine.trace() as trace_connected:
                model_connected()

with MultiWorldCounterfactual() as mwc_rev:
    with SearchForExplanation(
        supports=supports_connected.supports,
        antecedents={"Y": torch.tensor(1.0)},
        consequents={"X": torch.tensor(1.0)},
        witnesses={},
        alternatives={"Y": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=3):
            with pyro.poutine.trace() as trace_reverse:
                model_connected()

trace_connected.trace.compute_log_prob()
trace_independent.trace.compute_log_prob()
trace_reverse.trace.compute_log_prob()

Y_values_ind = trace_independent.trace.nodes["Y"]["value"]

log_probs_ind = trace_independent.trace.nodes["__cause____consequent_Y"][
    "fn"
].log_factor

with mwc_ind:
    nec_log_probs_ind = gather(log_probs_ind, IndexSet(**{"X": {1}}))
    suff_log_probs_ind = gather(log_probs_ind, IndexSet(**{"X": {2}}))

if torch.any(Y_values_ind == 1.0):
    assert nec_log_probs_ind.sum().exp() == 0.0
else:
    assert nec_log_probs_ind.sum().exp() == 1.0

assert torch.all(log_probs_ind.sum().exp() == 0)

if torch.any(Y_values_ind == 0.0):
    assert suff_log_probs_ind.sum().exp() == 0.0
else:
    assert suff_log_probs_ind.sum().exp() == 1.0

assert torch.all(
    trace_connected.trace.nodes["__cause____consequent_Y"]["fn"].log_factor.sum()
    == 0
)

log_probs_rev = trace_reverse.trace.nodes["__cause____consequent_X"]["fn"].log_factor
with mwc_rev:
    nec_log_probs_rev=gather(log_probs_rev, IndexSet(**{"Y": {1}}))
    suff_log_probs_rev=gather(log_probs_rev, IndexSet(**{"Y": {2}}))

X_values_rev = trace_reverse.trace.nodes["X"]["value"]
if torch.any(X_values_rev == 1.0):
    assert (
        nec_log_probs_rev
        .sum()
        .exp()
        == 0.0
    )
else:
    assert (
        nec_log_probs_rev
        .sum()
        .exp()
        == 1.0
    )

if torch.any(X_values_rev == 0.0):
    assert (
        suff_log_probs_rev
        .sum()
        .exp()
        == 0.0
    )
else:
    assert (
        suff_log_probs_rev
        .sum()
        .exp()
        == 1.0
    )

assert torch.all(
    log_probs_rev.sum()
    .exp()
    == 0
)

In [2]:
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import pytest
import torch

from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.counterfactual.ops import split
from chirho.explainable.handlers import random_intervention, sufficiency_intervention
from chirho.explainable.handlers.components import (  # consequent_eq_neq,
    ExtractSupports,
    consequent_eq,
    consequent_eq_neq,
    consequent_neq,
    undo_split,
)
from chirho.explainable.internals import uniform_proposal
from chirho.explainable.ops import preempt
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.interventional.handlers import do
from chirho.interventional.ops import intervene
from chirho.observational.handlers.condition import Factors

SUPPORT_CASES = [
    pyro.distributions.constraints.real,
    pyro.distributions.constraints.boolean,
    pyro.distributions.constraints.positive,
    pyro.distributions.constraints.interval(0, 10),
    pyro.distributions.constraints.interval(-5, 5),
    pyro.distributions.constraints.integer_interval(0, 2),
    pyro.distributions.constraints.integer_interval(0, 100),
]

In [54]:
# "event_shape", [(), (3,), (3, 2)]
# "plate_size", [4, 50, 200]

plate_size = 200
event_shape = (3,2)

factors = {
    "consequent": consequent_eq_neq(
        support=constraints.independent(constraints.real, len(event_shape)),
        # proposed_consequent=torch.Tensor([0.01], event_shape),
        proposed_consequent=torch.tensor(0.01).expand(event_shape),
        antecedents=["w"],
    )
}

# w_initial = (
#     dist.Normal(0, 0.1).expand(event_shape).to_event(len(event_shape)).sample()
# )

@Factors(factors=factors)
@pyro.plate("data", size=plate_size, dim=-4)
def model_ce():
    w = pyro.sample("w", dist.Normal(0, 0.1).expand(event_shape).to_event(len(event_shape)))
    consequent = pyro.deterministic("consequent", w * torch.tensor(0.1), event_dim = len(event_shape))
    assert w.shape == consequent.shape
    print(w.shape)
    print(consequent.shape)

antecedents = {
    "w": (
        torch.tensor(0.1).expand(event_shape),
        sufficiency_intervention(
            constraints.independent(constraints.real, len(event_shape)), ["w"]
        ),
    )
}

print(antecedents["w"])

with MultiWorldCounterfactual() as mwc_ce:
    with do(actions=antecedents):
        with pyro.poutine.trace() as trace_ce:
            model_ce()

trace_ce.trace.compute_log_prob()
nd = trace_ce.trace.nodes
with mwc_ce:
    eq_neq_log_probs_fact = gather(
        nd["__factor_consequent"]["fn"].log_factor,
        IndexSet(**{"w": {0}})
    )

    print(nd["__factor_consequent"]["fn"].log_factor.shape)
    print(indices_of(nd["__factor_consequent"]["fn"].log_factor))
    print(eq_neq_log_probs_fact.shape)

    eq_neq_log_probs_nec = gather(
        nd["__factor_consequent"]["fn"].log_factor,
        IndexSet(**{"w": {1}})
    )
    # print("consequent_shape", indices_of(nd["consequent"]["value"].shape, event_dim=len(event_shape)))
    consequent_suff = gather(
        nd["consequent"]["value"], IndexSet(**{"w": {2}}), event_dim=len(event_shape)
    )
    eq_neq_log_probs_suff = gather(
        nd["__factor_consequent"]["fn"].log_factor, IndexSet(**{"w": {2}})
    )

    # print(eq_neq_log_probs_suff.shape)
    # print(eq_neq_log_probs_fact.shape)
    print(eq_neq_log_probs_suff.squeeze())
    # print(consequent_suff)
    # print(dist.Normal(0.0, 0.1).log_prob(consequent_suff - torch.tensor(0.01)))

    print(eq_neq_log_probs_nec)
    # print(consequent_suff.shape)

    assert torch.equal(
        eq_neq_log_probs_fact, torch.zeros(eq_neq_log_probs_fact.shape)
    )

    print(dist.Normal(0.0, 0.1).log_prob(consequent_suff - torch.tensor(0.01)).sum().squeeze())
    # assert eq_neq_log_probs_nec.shape == consequent_suff.shape

    result = dist.Normal(0.0, 0.1).log_prob(consequent_suff - torch.tensor(0.01))
    for _ in range(len(event_shape)):
        result = torch.sum(result, dim=-1)

    print(result)

    assert torch.allclose(
        eq_neq_log_probs_suff.squeeze(),
        result.squeeze(),
    )
    assert eq_neq_log_probs_nec.sum().exp().item() == 0

(tensor([[0.1000, 0.1000],
        [0.1000, 0.1000],
        [0.1000, 0.1000]]), <function sufficiency_intervention.<locals>._sufficiency_intervention at 0x117471090>)
torch.Size([3, 200, 1, 1, 1, 3, 2])
torch.Size([3, 200, 1, 1, 1, 3, 2])
torch.Size([3, 200, 1, 1, 1])
IndexSet({'w': {0, 1, 2}})
torch.Size([1, 200, 1, 1, 1])
tensor([8.2223, 8.2605, 8.2698, 8.2651, 8.2059, 8.2255, 8.2357, 8.2512, 8.2502,
        8.2338, 8.2549, 8.2659, 8.2067, 8.2874, 8.2593, 8.2138, 8.2624, 8.2358,
        8.2504, 8.2528, 8.2693, 8.2498, 8.2330, 8.2527, 8.2848, 8.2551, 8.2632,
        8.2444, 8.2634, 8.2679, 8.2186, 8.2608, 8.2668, 8.2299, 8.2255, 8.2452,
        8.2578, 8.2228, 8.2077, 8.2434, 8.2753, 8.2417, 8.2302, 8.2745, 8.2561,
        8.2670, 8.2126, 8.1921, 8.2064, 8.2119, 8.2018, 8.2624, 8.2835, 8.2358,
        8.2572, 8.2236, 8.2636, 8.2805, 8.2737, 8.2429, 8.2934, 8.2900, 8.2351,
        8.2445, 8.2874, 8.2032, 8.2813, 8.2572, 8.2831, 8.1908, 8.2640, 8.2324,
        8.2388, 8.2196, 8.1854, 8

In [31]:
torch.tensor(0.1).expand(event_shape)

tensor([0.1000, 0.1000, 0.1000])

In [14]:
def model_independent(event_shape):
    X = pyro.sample("X", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    Y = pyro.sample("Y", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    return {"X": X, "Y": Y}


def model_connected(event_shape):
    X = pyro.sample("X", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    Y = pyro.deterministic("Y", X, event_dim = len(event_shape))
    return {"X": X, "Y": Y}


# @pytest.mark.parametrize("ante_cons", [("Y", "X")])
# @pytest.mark.parametrize(
#     "model",
#     [
#         model_independent,
#         model_connected
#     ],
# )
# @pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str)
def test_edge_eq_neq(model, ante_cons, event_shape):
    with ExtractSupports() as supports:
        model(event_shape)

    antecedent = ante_cons[0]
    consequent = ante_cons[1]

    with MultiWorldCounterfactual() as mwc:
        with SearchForExplanation(
            supports=supports.supports,
            antecedents={antecedent: torch.tensor(1.0).expand(event_shape)},
            consequents={consequent: torch.tensor(1.0).expand(event_shape)},
            witnesses={},
            alternatives={antecedent: torch.tensor(0.0).expand(event_shape)},
            antecedent_bias=-0.5,
            consequent_scale=0,
        ):
            with pyro.plate("sample", size=3):
                with pyro.poutine.trace() as trace:
                    model(event_shape)

    trace.trace.compute_log_prob()

    cons_values = trace.trace.nodes[consequent]["value"]

    log_probs = trace.trace.nodes[f"__cause____consequent_{consequent}"][
        "fn"
    ].log_factor

    with mwc:
        nec_log_probs = gather(log_probs, IndexSet(**{antecedent: {1}}))
        suff_log_probs = gather(log_probs, IndexSet(**{antecedent: {2}}))

    if torch.any(cons_values == 1.0):
        assert nec_log_probs.sum().exp() == 0.0
    else:
        assert nec_log_probs.sum().exp() == 1.0

    assert torch.all(log_probs.sum().exp() == 0)

    if torch.any(cons_values == 0.0):
        assert suff_log_probs.sum().exp() == 0.0
    else:
        assert suff_log_probs.sum().exp() == 1.0

    assert torch.all(
        trace.trace.nodes[f"__cause____consequent_{consequent}"]["fn"].log_factor.sum().exp()
        == 0
    )

test_edge_eq_neq(model_connected, ("Y", "X"), (3,2))

In [93]:
def model_three_converge(event_shape):
    X = pyro.sample("X", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    Y = pyro.sample("Y", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    Z = pyro.deterministic("Z", torch.min(X, Y), event_dim=len(event_shape))
    return {"X": X, "Y": Y, "Z": Z}


# X -> Y, X -> Z
def model_three_diverge(event_shape):
    X = pyro.sample("X", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    Y = pyro.deterministic("Y", X, event_dim=len(event_shape))
    Z = pyro.deterministic("Z", X, event_dim=len(event_shape))
    return {"X": X, "Y": Y, "Z": Z}


# X -> Y -> Z
def model_three_chain(event_shape):
    X = pyro.sample("X", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    Y = pyro.deterministic("Y", X, event_dim=len(event_shape))
    Z = pyro.deterministic("Z", Y, event_dim=len(event_shape))
    return {"X": X, "Y": Y, "Z": Z}


# X -> Y, X -> Z, Y -> Z
def model_three_complete(event_shape):
    X = pyro.sample("X", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    Y = pyro.deterministic("Y", X, event_dim=len(event_shape))
    Z = pyro.deterministic("Z", torch.max(X, Y), event_dim=len(event_shape))
    return {"X": X, "Y": Y, "Z": Z}


# X -> Y    Z
def model_three_isolate(event_shape):
    X = pyro.sample("X", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    Y = pyro.deterministic("Y", X, event_dim=len(event_shape))
    Z = pyro.sample("Z", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    return {"X": X, "Y": Y, "Z": Z}


# X     Y    Z
def model_three_independent(event_shape):
    X = pyro.sample("X", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    Y = pyro.sample("Y", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    Z = pyro.sample("Z", dist.Bernoulli(0.5).expand(event_shape).to_event(len(event_shape)))
    return {"X": X, "Y": Y, "Z": Z}


@pytest.mark.parametrize("ante_cons", [("X", "Y", "Z"), ("X", "Z", "Y")])
@pytest.mark.parametrize(
    "model",
    [
        model_three_converge,
        model_three_diverge,
        model_three_chain,
        # model_three_complete,
        # model_three_isolate,
        # model_three_independent,
    ],
)
@pytest.mark.parametrize("event_shape", [(), (3,), (3, 2)], ids=str)
def test_eq_neq_three_variables(model, ante_cons, event_shape):
    ante1, ante2, cons = ante_cons
    with ExtractSupports() as supports:
        model(event_shape)
        for var, sup in supports.supports.items():
            if isinstance(sup, constraints.independent):
                sup.base_constraint = constraints.boolean
            else:
                sup.base_constraint = constraints.boolean
        

    with MultiWorldCounterfactual() as mwc:
        with SearchForExplanation(
            supports=supports.supports,
            antecedents={ante1: torch.tensor(1.0).expand(event_shape), ante2: torch.tensor(1.0).expand(event_shape)},
            consequents={cons: torch.tensor(1.0).expand(event_shape)},
            witnesses={},
            alternatives={ante1: torch.tensor(0.0).expand(event_shape), ante2: torch.tensor(0.0).expand(event_shape)},
            antecedent_bias=-0.5,
            consequent_scale=0,
        ):
            with pyro.plate("sample", size=1):
                with pyro.poutine.trace() as trace:
                    model(event_shape)

    trace.trace.compute_log_prob()
    nodes = trace.trace.nodes

    values = nodes[cons]["value"]
    log_probs = nodes[f"__cause____consequent_{cons}"]["fn"].log_factor

    fact_worlds = IndexSet(**{name: {0} for name in [ante1, ante2]})
    nec_worlds = IndexSet(**{name: {1} for name in [ante1, ante2]})
    suff_worlds = IndexSet(**{name: {2} for name in [ante1, ante2]})
    with mwc:
        assert indices_of(log_probs) == {ante1: {0, 1, 2}, ante2: {0, 1, 2}}

        fact_lp = gather(log_probs, fact_worlds)
        fact_value = gather(values, fact_worlds, event_dim=len(event_shape))
        assert fact_lp.exp().item() == 1

        nec_value = gather(values, nec_worlds, event_dim=len(event_shape))
        nec_lp = gather(log_probs, nec_worlds)

        if torch.equal(nec_value, fact_value) & (not torch.allclose(nec_value, torch.tensor(0.0))):
            assert nec_lp.exp().item() == 0.0
        elif torch.allclose(nec_value, torch.tensor(0.0)):
            assert nec_lp.exp().item() == 1.0

        suff_value = gather(values, suff_worlds, event_dim=len(event_shape))
        suff_lp = gather(log_probs, suff_worlds)

        if torch.equal(suff_value, fact_value) & (not torch.allclose(suff_value, torch.tensor(1.0))):
            assert suff_lp.exp().item() == 0.0
        elif torch.allclose(suff_value, torch.tensor(1.0)):
            assert suff_lp.exp().item() == 1.0

    assert torch.allclose(log_probs.squeeze().fill_diagonal_(0.0), torch.tensor(0.0))

test_eq_neq_three_variables(model_three_converge, ("X", "Z", "Y"), (3,2))