In [5]:
import math

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch

from typing import Callable, Mapping, Optional, TypeVar, Union


from chirho.explainable.handlers.components import (
    consequent_eq_neq,
    random_intervention,
    sufficiency_intervention,
    undo_split,
)

from chirho.observational.handlers.condition import Factors
from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.explainable.handlers.components import undo_split, consequent_eq_neq, sufficiency_intervention
from chirho.explainable.handlers.explanation import SearchForExplanation, SplitSubsets
from chirho.explainable.handlers import ExtractSupports
from chirho.observational.handlers.condition import Factors
from chirho.interventional.handlers import do
from chirho.explainable.handlers.preemptions import Preemptions
from chirho.indexed.ops import IndexSet, gather
from chirho.observational.handlers.condition import condition
from chirho.indexed.ops import indices_of

# Generic type variables for annotations.
# NOTE(review): neither appears to be used in the visible cells — confirm before removing.
S = TypeVar("S")
T = TypeVar("T")

In [4]:
def model_three_independent():
    """Sample the sites ``X``, ``Y`` and ``Z``, each from its own ``Bernoulli(0.5)``.

    No site's distribution depends on another site's value.

    Returns:
        dict mapping each site name (``"X"``, ``"Y"``, ``"Z"``) to its sampled value.
    """
    # Sites are sampled in the order X, Y, Z; each gets a fresh distribution object.
    return {
        site: pyro.sample(site, dist.Bernoulli(0.5))
        for site in ("X", "Y", "Z")
    }

# Run the model once under ExtractSupports; the collected ``.supports`` are
# passed to SearchForExplanation below.
with ExtractSupports() as supports_independent:
    model_three_independent()

# Trace one run (plate of size 1) of the model under SearchForExplanation with
# antecedents X=1, Y=1, consequent Z=1 and alternatives X=0, Y=0.
with MultiWorldCounterfactual() as mwc_independent:
    with SearchForExplanation(
        supports=supports_independent.supports,
        antecedents={"X": torch.tensor(1.0), "Y": torch.tensor(1.0)},
        consequents={"Z": torch.tensor(1.0)},
        witnesses={},
        alternatives={"X": torch.tensor(0.0), "Y": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=1):
            with pyro.poutine.trace() as trace_independent:
                model_three_independent()

# Fixed: the original only referenced the bound method (no parentheses), so the
# log-probabilities were never actually computed.
trace_independent.trace.compute_log_prob()
nodes = trace_independent.trace.nodes
# TODO(review): this shape check was already disabled in the original; confirm the
# expected shape for this model before re-enabling it.
# assert nodes["__cause____consequent_Z"]["fn"].log_factor.shape == torch.Size([3, 3, 1, 1, 1, 1])

# Sampled values of each site, followed by the log-factor of the consequent site.
print(nodes["X"]["value"].squeeze())
print(nodes["Y"]["value"].squeeze())
print(nodes["Z"]["value"].squeeze())
print(nodes["__cause____consequent_Z"]["fn"].log_factor.squeeze())

necessity_log_probs tensor([0.])
sufficiency_log_probs tensor([-inf])
IndexSet({'X': {0}, 'Y': {0}})
nec_suff_log_prob_partitioned {IndexSet({'X': {0}}): tensor([0.]), IndexSet({'Y': {0}}): tensor([0.]), IndexSet({'X': {1}}): tensor([0.]), IndexSet({'X': {2}}): tensor([-inf]), IndexSet({'Y': {1}}): tensor([0.]), IndexSet({'Y': {2}}): tensor([-inf])}
new_value tensor([[[[[[-inf]]]],



         [[[[-inf]]]],



         [[[[-inf]]]]]])
tensor([1., 0., 1.])
tensor([1., 0., 1.])
tensor(0.)
tensor([-inf, -inf, -inf])


In [8]:
# X -> Z <- Y  (X and Y are sampled independently; Z depends on both via min(X, Y))

def model_three_diverge():
    """Sample ``X`` and ``Y`` from ``Bernoulli(0.5)`` and ``Z`` from ``Bernoulli(min(X, Y))``.

    ``Z`` is the only site whose distribution depends on other sites: its
    success probability is the elementwise minimum of ``X`` and ``Y``.

    NOTE(review): the function name suggests a diverging structure
    (X -> Y, X -> Z), but the code implements X -> Z <- Y — confirm which
    one is intended.

    Returns:
        dict mapping each site name (``"X"``, ``"Y"``, ``"Z"``) to its sampled value.
    """
    x = pyro.sample("X", dist.Bernoulli(0.5))
    y = pyro.sample("Y", dist.Bernoulli(0.5))
    z_prob = torch.min(x, y)
    z = pyro.sample("Z", dist.Bernoulli(z_prob))
    return dict(X=x, Y=y, Z=z)

# Run the model once under ExtractSupports; the collected ``.supports`` are
# passed to SearchForExplanation below.
with ExtractSupports() as supports_diverge:
    model_three_diverge()

# Trace one run (plate of size 1) of the model under SearchForExplanation with
# antecedents Y=1, X=1 and consequent Z=1.
with MultiWorldCounterfactual() as mwc_diverge:
    with SearchForExplanation(
        # Fixed: the original passed ``supports_independent.supports`` (copy-paste
        # from the previous cell), which made this cell depend on another cell's
        # state instead of the supports extracted from this model.
        supports=supports_diverge.supports,
        antecedents={"Y": torch.tensor(1.0), "X": torch.tensor(1.0)},
        consequents={"Z": torch.tensor(1.0)},
        witnesses={},
        # NOTE(review): the alternatives are identical to the antecedents here,
        # whereas the independent-model cell uses 0.0 — confirm this is intended.
        alternatives={"Y": torch.tensor(1.0), "X": torch.tensor(1.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        with pyro.plate("sample", size=1):
            with pyro.poutine.trace() as trace_diverge:
                model_three_diverge()

# Fixed: the original only referenced the bound method (no parentheses), so the
# log-probabilities were never actually computed.
trace_diverge.trace.compute_log_prob()
nodes = trace_diverge.trace.nodes
print(nodes["__cause____consequent_Z"]["fn"].log_factor)
# The leading two dimensions of size 3 are checked alongside four singleton dimensions.
assert nodes["__cause____consequent_Z"]["fn"].log_factor.shape == torch.Size([3, 3, 1, 1, 1, 1])

tensor([[[[[[0.]]]],



         [[[[-inf]]]],



         [[[[0.]]]]],




        [[[[[-inf]]]],



         [[[[-inf]]]],



         [[[[0.]]]]],




        [[[[[0.]]]],



         [[[[-inf]]]],



         [[[[0.]]]]]])


In [None]:
# X -> Y -> Z, X -> Z

def model_three_complete():
    """Sample ``X ~ Bernoulli(0.5)``, ``Y ~ Bernoulli(X)`` and ``Z ~ Bernoulli(max(X, Y))``.

    ``Y`` depends on ``X``, and ``Z`` depends on both ``X`` and ``Y``.

    Returns:
        dict mapping each site name (``"X"``, ``"Y"``, ``"Z"``) to its sampled value.
    """
    X = pyro.sample("X", dist.Bernoulli(0.5))
    Y = pyro.sample("Y", dist.Bernoulli(X))
    # Fixed: the builtin ``max`` compares tensors via their truth value, which
    # raises for tensors with more than one element (e.g. under a plate or in
    # multiple counterfactual worlds). ``torch.max`` is elementwise and
    # broadcasts, matching the ``torch.min`` used in ``model_three_diverge``.
    Z = pyro.sample("Z", dist.Bernoulli(torch.max(X, Y)))
    return {"X": X, "Y": Y, "Z": Z}


# Run the model once under ExtractSupports and keep the handler as
# ``supports_complete``.
with ExtractSupports() as supports_complete:
    model_three_complete()

In [None]:
# X -> Y, Z isolated (Z does not depend on X or Y)

def model_three_isolate():
    """Sample ``X ~ Bernoulli(0.5)``, ``Y ~ Bernoulli(X)`` and ``Z ~ Bernoulli(0.5)``.

    ``Y`` depends on ``X``; the distribution of ``Z`` uses neither ``X`` nor ``Y``.

    Returns:
        dict mapping each site name (``"X"``, ``"Y"``, ``"Z"``) to its sampled value.
    """
    x = pyro.sample("X", dist.Bernoulli(0.5))
    y = pyro.sample("Y", dist.Bernoulli(x))
    z = pyro.sample("Z", dist.Bernoulli(0.5))
    return dict(X=x, Y=y, Z=z)


# Run the model once under ExtractSupports and keep the handler as
# ``supports_isolate``.
with ExtractSupports() as supports_isolate:
    model_three_isolate()