In [1]:
import math

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch

from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.explainable.handlers.components import undo_split
from chirho.explainable.handlers.explanation import SearchForExplanation, SplitSubsets
from chirho.explainable.handlers import ExtractSupports
from chirho.explainable.handlers.preemptions import Preemptions
from chirho.indexed.ops import IndexSet, gather
from chirho.observational.handlers.condition import condition

In [2]:
# def test_edge_eq_neq():

def model_independent():
    """Toy model with two unrelated coin flips.

    Registers the sample sites ``"X"`` and ``"Y"``, each drawn from
    ``Bernoulli(0.5)``; neither draw is passed to the other. Returns ``None``.
    """
    for site_name in ("X", "Y"):
        pyro.sample(site_name, dist.Bernoulli(0.5))

def model_connected():
    """Toy model in which ``Y`` is driven by ``X``.

    ``"X"`` is drawn from ``Bernoulli(0.5)`` and the drawn value is then used
    as the parameter of the Bernoulli from which ``"Y"`` is sampled.
    Returns ``None``.
    """
    x_draw = pyro.sample("X", dist.Bernoulli(0.5))
    pyro.sample("Y", dist.Bernoulli(x_draw))



# Run each model once under an ExtractSupports handler. The handler objects are
# kept because their `.supports` attribute is passed to SearchForExplanation
# in the searches below.
# NOTE(review): presumably `.supports` maps each sample site to the support of
# its distribution -- confirm against the ChiRho documentation.
with ExtractSupports() as supports_independent:
    model_independent()

# The connected model gets its own handler; its supports are used by both the
# forward (X -> Y) and the reverse (Y -> X) searches.
with ExtractSupports() as supports_connected:
    model_connected()

def _run_explanation_search(model, supports, antecedent, consequent, num_samples=3):
    """Trace ``model`` under a ``SearchForExplanation`` handler stack.

    The three searches in this cell were copy-pasted and differed only in the
    model, the supports and the site names; this helper holds the shared
    handler stack so each search becomes a single call.

    Args:
        model: zero-argument callable containing the ``pyro.sample`` sites.
        supports: supports mapping handed to ``SearchForExplanation``.
        antecedent: name of the candidate-cause site. It is given the
            antecedent value ``1.0`` and the alternative value ``0.0``.
        consequent: name of the outcome site, given the consequent value ``1.0``.
        num_samples: size of the ``"sample"`` plate wrapped around the model.

    Returns:
        ``(mwc, trace)``: the objects bound by ``MultiWorldCounterfactual() as``
        and ``pyro.poutine.trace() as`` respectively, as in the inline version.
    """
    with MultiWorldCounterfactual() as mwc:
        with SearchForExplanation(
            supports=supports,
            antecedents={antecedent: torch.tensor(1.0)},
            consequents={consequent: torch.tensor(1.0)},
            witnesses={},
            alternatives={antecedent: torch.tensor(0.0)},
            antecedent_bias=-0.5,
            consequent_scale=0,
        ):
            # The trace handler sits inside the plate, as in the original cell.
            with pyro.plate("sample", size=num_samples):
                with pyro.poutine.trace() as traced:
                    model()
    return mwc, traced


# X as candidate cause of Y in the model where X and Y are unrelated.
mwc_independent, trace_independent = _run_explanation_search(
    model_independent, supports_independent.supports, antecedent="X", consequent="Y"
)

# X as candidate cause of Y in the model where Y is sampled from Bernoulli(X).
mwc_connected, trace_connected = _run_explanation_search(
    model_connected, supports_connected.supports, antecedent="X", consequent="Y"
)

# Reverse direction on the connected model: Y as candidate cause of X.
mwc_reverse, trace_reverse = _run_explanation_search(
    model_connected, supports_connected.supports, antecedent="Y", consequent="X"
)


# Populate the per-site log-probabilities of each trace. The original lines
# only referenced the bound method (no parentheses), which evaluates the
# attribute and discards it without computing anything.
trace_connected.trace.compute_log_prob()
trace_independent.trace.compute_log_prob()
trace_reverse.trace.compute_log_prob()

#print(trace_independent.trace.nodes["__cause____consequent_Y"]["fn"].log_factor)
#print(trace_independent.trace.nodes["Y"]["value"])

# Values recorded at the consequent site of each traced run.
Y_values_ind = trace_independent.trace.nodes["Y"]["value"]
Y_values_con = trace_connected.trace.nodes["Y"]["value"]
X_values_rev = trace_reverse.trace.nodes["X"]["value"]


def _independent_factor_product(first_index):
    """Exponentiated sum of the independent run's consequent log factor.

    Selects ``[first_index, 0, 0, 0, :]`` from the log factor stored at the
    ``__cause____consequent_Y`` site, sums it and exponentiates the result.
    """
    log_factor = trace_independent.trace.nodes["__cause____consequent_Y"]["fn"].log_factor
    return log_factor[first_index, 0, 0, 0, :].sum().exp()


# NOTE(review): indices 1 and 2 along the first axis presumably select the
# two counterfactual worlds created by the search -- confirm against ChiRho.

# Index 1: expect a product of 0 if any sampled Y equals 1, otherwise 1.
any_y_is_one = torch.any(Y_values_ind == 1.)
if any_y_is_one:
    print("testing with ", Y_values_ind)
assert _independent_factor_product(1) == (0. if any_y_is_one else 1.)

# Index 2: expect a product of 0 if any sampled Y equals 0, otherwise 1.
any_y_is_zero = torch.any(Y_values_ind == 0.)
assert _independent_factor_product(2) == (0. if any_y_is_zero else 1.)

# In the connected run the consequent log factor must sum to exactly zero.
connected_log_factor = trace_connected.trace.nodes["__cause____consequent_Y"]["fn"].log_factor
assert torch.all(connected_log_factor.sum() == 0)

# Show the reverse-direction values and their consequent log factor.
print(X_values_rev)
print(trace_reverse.trace.nodes["__cause____consequent_X"]["fn"].log_factor)



# assert torch.all(trace_connected.trace.nodes["__cause____consequent_Y"]["fn"].log_factor[0,0,0,0,:] == 0)

consequent tensor([0., 0., 0.])
consequent tensor([[[[[1., 1., 0.]]]],



        [[[[0., 0., 0.]]]],



        [[[[1., 1., 1.]]]]])
consequent tensor([0., 1., 0.])
tensor([0., 1., 0.])
tensor([[[[[-inf, 0., -inf]]]]])


In [4]:
# Re-run of the forward search on the connected model: X (antecedent value 1,
# alternative value 0) as candidate cause of Y == 1.
with MultiWorldCounterfactual() as mwc_connected:
    with SearchForExplanation(
        supports=supports_connected.supports,
        antecedents={"X": torch.tensor(1.0)},
        consequents={"Y": torch.tensor(1.0)},
        witnesses={},
        alternatives={"X": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        # The plate is entered first and the trace second, so the trace handler
        # still sits inside the size-3 "sample" plate.
        with pyro.plate("sample", size=3), pyro.poutine.trace() as trace_connected:
            model_connected()

consequent tensor([[[[[1., 1., 1.]]]],



        [[[[0., 0., 0.]]]],



        [[[[1., 1., 1.]]]]])


KeyboardInterrupt: 

In [6]:
# Re-run of the reverse search on the connected model: Y (antecedent value 1,
# alternative value 0) as candidate cause of X == 1.
with MultiWorldCounterfactual() as mwc_reverse:
    with SearchForExplanation(
        supports=supports_connected.supports,
        antecedents={"Y": torch.tensor(1.0)},
        consequents={"X": torch.tensor(1.0)},
        witnesses={},
        alternatives={"Y": torch.tensor(0.0)},
        antecedent_bias=-0.5,
        consequent_scale=0,
    ):
        # The plate is entered first and the trace second, so the trace handler
        # still sits inside the size-3 "sample" plate.
        with pyro.plate("sample", size=3), pyro.poutine.trace() as trace_reverse:
            model_connected()
                    

consequent tensor([1., 0., 1.])
