In [2]:
import math

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import pytest
import torch

from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.explainable.handlers import ExtractSupports
from chirho.explainable.handlers.components import undo_split
from chirho.explainable.handlers.explanation import SearchForExplanation, SplitSubsets
from chirho.explainable.handlers.preemptions import Preemptions
from chirho.indexed.ops import IndexSet, gather
from chirho.observational.handlers.condition import condition

In [13]:
def model_independent():
    X = pyro.sample("X", dist.Bernoulli(0.5))
    Y = pyro.sample("Y", dist.Bernoulli(0.5))
    return {"X": X, "Y": Y}

def model_connected():
    X = pyro.sample("X", dist.Bernoulli(0.5))
    Y = pyro.sample("Y", dist.Bernoulli(X))
    return {"X": X, "Y": Y}

with ExtractSupports() as supports_independent:
    model_independent()

with ExtractSupports() as supports_connected:
    model_connected()

with MultiWorldCounterfactual() as mwc_ind:
    with SearchForExplanation(
        supports=supports_independent.supports,
        antecedents={"X": torch.tensor(1.0)},
        consequents={"Y": torch.tensor(1.0)},
        witnesses={},
        alternatives={"X": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=3):
            with pyro.poutine.trace() as trace_independent:
                model_independent()

with MultiWorldCounterfactual() as mwc_con:
    with SearchForExplanation(
        supports=supports_connected.supports,
        antecedents={"X": torch.tensor(1.0)},
        consequents={"Y": torch.tensor(1.0)},
        witnesses={},
        alternatives={"X": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=3):
            with pyro.poutine.trace() as trace_connected:
                model_connected()

with MultiWorldCounterfactual() as mwc_rev:
    with SearchForExplanation(
        supports=supports_connected.supports,
        antecedents={"Y": torch.tensor(1.0)},
        consequents={"X": torch.tensor(1.0)},
        witnesses={},
        alternatives={"Y": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=3):
            with pyro.poutine.trace() as trace_reverse:
                model_connected()

trace_connected.trace.compute_log_prob()
trace_independent.trace.compute_log_prob()
trace_reverse.trace.compute_log_prob()

Y_values_ind = trace_independent.trace.nodes["Y"]["value"]

log_probs_ind = trace_independent.trace.nodes["__cause____consequent_Y"][
    "fn"
].log_factor

with mwc_ind:
    nec_log_probs_ind = gather(log_probs_ind, IndexSet(**{"X": {1}}))
    suff_log_probs_ind = gather(log_probs_ind, IndexSet(**{"X": {2}}))

if torch.any(Y_values_ind == 1.0):
    assert nec_log_probs_ind.sum().exp() == 0.0
else:
    assert nec_log_probs_ind.sum().exp() == 1.0

assert torch.all(log_probs_ind.sum().exp() == 0)

if torch.any(Y_values_ind == 0.0):
    assert suff_log_probs_ind.sum().exp() == 0.0
else:
    assert suff_log_probs_ind.sum().exp() == 1.0

assert torch.all(
    trace_connected.trace.nodes["__cause____consequent_Y"]["fn"].log_factor.sum()
    == 0
)

log_probs_rev = trace_reverse.trace.nodes["__cause____consequent_X"]["fn"].log_factor
with mwc_rev:
    nec_log_probs_rev=gather(log_probs_rev, IndexSet(**{"Y": {1}}))
    suff_log_probs_rev=gather(log_probs_rev, IndexSet(**{"Y": {2}}))

X_values_rev = trace_reverse.trace.nodes["X"]["value"]
if torch.any(X_values_rev == 1.0):
    assert (
        nec_log_probs_rev
        .sum()
        .exp()
        == 0.0
    )
else:
    assert (
        nec_log_probs_rev
        .sum()
        .exp()
        == 1.0
    )

if torch.any(X_values_rev == 0.0):
    assert (
        suff_log_probs_rev
        .sum()
        .exp()
        == 0.0
    )
else:
    assert (
        suff_log_probs_rev
        .sum()
        .exp()
        == 1.0
    )

assert torch.all(
    log_probs_rev.sum()
    .exp()
    == 0
)

In [1]:
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import pytest
import torch

from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.counterfactual.ops import split
from chirho.explainable.handlers import random_intervention, sufficiency_intervention
from chirho.explainable.handlers.components import (  # consequent_eq_neq,
    ExtractSupports,
    consequent_eq,
    consequent_eq_neq,
    consequent_neq,
    undo_split,
)
from chirho.explainable.internals import uniform_proposal
from chirho.explainable.ops import preempt
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.interventional.handlers import do
from chirho.interventional.ops import intervene
from chirho.observational.handlers.condition import Factors

SUPPORT_CASES = [
    pyro.distributions.constraints.real,
    pyro.distributions.constraints.boolean,
    pyro.distributions.constraints.positive,
    pyro.distributions.constraints.interval(0, 10),
    pyro.distributions.constraints.interval(-5, 5),
    pyro.distributions.constraints.integer_interval(0, 2),
    pyro.distributions.constraints.integer_interval(0, 100),
]

In [3]:
# "event_shape", [(), (3,), (3, 2)]
# "plate_size", [4, 50, 200]

event_shape = ()

factors = {
    "consequent": consequent_eq_neq(
        support=constraints.independent(constraints.real, len(event_shape)),
        # proposed_consequent=torch.Tensor([0.01], event_shape),
        proposed_consequent=torch.tensor(0.01).expand(event_shape),
        antecedents=["w"],
    )
}

# w_initial = (
#     dist.Normal(0, 0.1).expand(event_shape).to_event(len(event_shape)).sample()
# )

# @Factors(factors=factors)
def model_ce():
    w = pyro.sample("w", dist.Normal(0, 0.1).expand(event_shape).to_event(len(event_shape)))
    consequent = pyro.deterministic("consequent", w * torch.tensor(0.1), event_dim = len(event_shape))
    assert w.shape == consequent.shape
    print(w.shape)
    print(consequent.shape)

antecedents = {
    "w": (
        torch.tensor(0.1).expand(event_shape),
        sufficiency_intervention(
            constraints.independent(constraints.real, len(event_shape)), ["w"]
        ),
    )
}

print(antecedents["w"])

with MultiWorldCounterfactual() as mwc_ce:
    with do(actions=antecedents):
        with Factors(factors=factors):
            with pyro.poutine.trace() as trace_ce:
                model_ce()

nd = trace_ce.trace.nodes
trace_ce.trace.compute_log_prob()
with mwc_ce:
    eq_neq_log_probs_fact = gather(
        nd["__factor_consequent"]["fn"].log_factor,
        IndexSet(**{"w": {0}}, event_dim=len(event_shape)),
    )
    eq_neq_log_probs_nec = gather(
        nd["__factor_consequent"]["fn"].log_factor,
        IndexSet(**{"w": {1}}, event_dim=len(event_shape)),
    )
    print("consequent_shape", indices_of(nd["consequent"]["value"].shape))
    consequent_suff = gather(
        nd["consequent"]["value"], IndexSet(**{"w": {2}}, event_dim=len(event_shape))
    )
    eq_neq_log_probs_suff = gather(
        nd["__factor_consequent"]["fn"].log_factor, IndexSet(**{"w": {2}}), event_dim=len(event_shape)
    )

    print(eq_neq_log_probs_suff.shape)
    print(eq_neq_log_probs_fact.shape)
    print(eq_neq_log_probs_suff)
    print(consequent_suff)
    print(dist.Normal(0.0, 0.1).log_prob(consequent_suff - torch.tensor(0.01)))

    print(eq_neq_log_probs_nec.shape)
    print(consequent_suff.shape)

    assert torch.equal(
        eq_neq_log_probs_fact, torch.zeros(eq_neq_log_probs_fact.shape)
    )
    assert eq_neq_log_probs_nec.shape == consequent_suff.shape
    assert torch.equal(
        eq_neq_log_probs_suff,
        dist.Normal(0.0, 0.1).log_prob(consequent_suff - torch.tensor(0.01)),
    )
    assert eq_neq_log_probs_nec.sum().exp() == 0

(tensor(0.1000), <function sufficiency_intervention.<locals>._sufficiency_intervention at 0x1290a5ab0>)
torch.Size([3, 1, 1, 1, 1])
torch.Size([3, 1, 1, 1, 1])
consequent_shape IndexSet({'w': {0, 1, 2}})
torch.Size([1, 1, 1, 1, 1])
torch.Size([1, 1, 1, 1, 1])
tensor([[[[[1.3788]]]]])
tensor([[[[[0.0001]]]]])
tensor([[[[[1.3788]]]]])
torch.Size([1, 1, 1, 1, 1])
torch.Size([1, 1, 1, 1, 1])


In [31]:
torch.tensor(0.1).expand(event_shape)

tensor([0.1000, 0.1000, 0.1000])