In [1]:
import math

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch

from typing import Callable, Mapping, Optional, TypeVar, Union


from chirho.explainable.handlers.components import (
    consequent_eq_neq,
    random_intervention,
    sufficiency_intervention,
    undo_split,
)

from chirho.observational.handlers.condition import Factors
from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.explainable.handlers.components import undo_split, consequent_eq_neq, sufficiency_intervention
from chirho.explainable.handlers.explanation import SearchForExplanation, SplitSubsets
from chirho.explainable.handlers import ExtractSupports
from chirho.observational.handlers.condition import Factors
from chirho.interventional.handlers import do
from chirho.explainable.handlers.preemptions import Preemptions
from chirho.indexed.ops import IndexSet, gather
from chirho.observational.handlers.condition import condition
from chirho.indexed.ops import indices_of

S = TypeVar("S")
T = TypeVar("T")

In [2]:
# def test_edge_eq_neq():

def model_three_independent():
    X = pyro.sample("X", dist.Bernoulli(0.5))
    Y = pyro.sample("Y", dist.Bernoulli(0.5))
    Z = pyro.sample("Z", dist.Bernoulli(0.5))
    return {"X": X, "Y": Y, "Z": Z}

with ExtractSupports() as supports_independent:
    model_three_independent()
    
antecedents = {"X": (torch.tensor(0.0), torch.tensor(1.0)),
              "Y": (torch.tensor(0.0), torch.tensor(1.0))}

with MultiWorldCounterfactual() as m_ind_do:
    with do(actions = antecedents):
        with pyro.poutine.trace() as tr_do:
            model_three_independent()
        
nodes_do = tr_do.trace.nodes

with m_ind_do:
    print("factual", get_factual_indices())
    print(indices_of(nodes_do["X"]["value"]))
    print(indices_of(nodes_do["Y"]["value"]))
    print(indices_of(nodes_do["Z"]["value"]))

factual IndexSet({'X': {0}, 'Y': {0}})
IndexSet({'X': {0, 1, 2}})
IndexSet({'Y': {0, 1, 2}})
IndexSet({})


In [6]:
antecedents={"X": torch.tensor(1.0), "Y": torch.tensor(1.0)}
alternatives={"X": torch.tensor(0.0), "Y": torch.tensor(0.0)}
supports=supports_independent.supports
antecedent_bias = 0.0 #-0.5
prefix = "__cause__"
consequents={"Z": torch.tensor(1.0)}
consequent_scale=0
factors = None
witnesses = {}
preemptions = None
witness_bias = 0.0



alternatives = (
        {a: alternatives[a] for a in antecedents.keys()}
        if alternatives is not None
        else {
            a: random_intervention(supports[a], name=f"{prefix}_alternative_{a}")
            for a in antecedents.keys()
        }
    )

with MultiWorldCounterfactual() as m_ind_do:
    with do(actions = alternatives):
        with pyro.poutine.trace() as tr_do:
            model_three_independent()
        
nodes_do = tr_do.trace.nodes

with m_ind_do:
    print("factual", get_factual_indices())
    print(indices_of(nodes_do["X"]["value"]))
    print(indices_of(nodes_do["Y"]["value"]))
    print(indices_of(nodes_do["Z"]["value"]))

factual IndexSet({'X': {0}, 'Y': {0}})
IndexSet({'X': {0, 1}})
IndexSet({'Y': {0, 1}})
IndexSet({})


In [38]:
print(consequents)

sufficiency_actions = {
    a: (
        antecedents[a]
        if antecedents[a] is not None
        else sufficiency_intervention(supports[a], antecedents=antecedents.keys())
    )
    for a in antecedents.keys()
}

antecedent_handler = SplitSubsets(
    {a: supports[a] for a in antecedents.keys()},
    {a: (alternatives[a], sufficiency_actions[a]) for a in antecedents.keys()},  # type: ignore
    bias=antecedent_bias,
    prefix=f"{prefix}__antecedent_",
)


witness_handler = Preemptions(
        (
            {w: preemptions[w] for w in witnesses}
            if preemptions is not None
            else {
                w: undo_split(supports[w], antecedents=antecedents.keys())
                for w in witnesses
            }
        ),
        bias=witness_bias,
        prefix=f"{prefix}__witness_",
    )


consequent_handler: Factors[T] = Factors(
        (
            {c: factors[c] for c in consequents.keys()}
            if factors is not None
            else {
                c: consequent_eq_neq(
                    support=supports[c],
                    proposed_consequent=consequents[c],  # added this
                    antecedents=antecedents.keys(),
                    scale=consequent_scale,
                )
                for c in consequents.keys()
            }
        ),
        prefix=f"{prefix}__consequent_",
    )


with MultiWorldCounterfactual() as m_ind_do:
    with antecedent_handler, witness_handler, consequent_handler:
        with pyro.poutine.trace() as tr_do:
            model_three_independent()
        
nodes_do = tr_do.trace.nodes

with m_ind_do:
    print("factual", get_factual_indices())
    print(indices_of(nodes_do["X"]["value"]))
    print(indices_of(nodes_do["Y"]["value"]))
    print(indices_of(nodes_do["Z"]["value"]))


{'Z': tensor(1.)}


KeyboardInterrupt: 

In [46]:
antecedents={"X": torch.tensor(1.0), "Y": torch.tensor(1.0)}
alternatives={"X": torch.tensor(0.0), "Y": torch.tensor(0.0)}
supports=supports_independent.supports
antecedent_bias = 0.0 #-0.5
prefix = "__cause__"
consequents={"Z": torch.tensor(1.0)}
consequent_scale=0
factors = None
witnesses = {}
preemptions = None
witness_bias = 0.0


with MultiWorldCounterfactual() as mwc_independent:
        with SearchForExplanation(
            supports=supports_independent.supports,
            antecedents={"X": torch.tensor(1.0), "Y": torch.tensor(1.0)},
            consequents={"Z": torch.tensor(1.0)},
            witnesses={},
            alternatives={"X": torch.tensor(0.0), "Y": torch.tensor(0.0)},
            antecedent_bias=-0.5,
            consequent_scale=0,
        ):  
        # with SearchForExplanation(
        #     supports=supports_independent.supports,
        #     antecedents={"X": torch.tensor(1.0), "Z": torch.tensor(1.0)},
        #     consequents={"Y": torch.tensor(1.0)},
        #     witnesses={},
        #     alternatives={"X": torch.tensor(0.0), "Z": torch.tensor(0.0)},
        #     antecedent_bias=-0.5,
        #     consequent_scale=0,
        # ):
            with pyro.plate("sample", size=1):
                with pyro.poutine.trace() as trace_independent:
                    model_three_independent()

trace_independent.trace.compute_log_prob

with mwc_independent:
    print(get_factual_indices())
    print(indices_of(trace_independent.trace.nodes["X"]["value"]))
    print(indices_of(trace_independent.trace.nodes["Z"]["value"]))
    print(indices_of(trace_independent.trace.nodes["Y"]["value"]))
    print(indices_of(trace_independent.trace.nodes["__cause____consequent_Z"]["value"]))

trace_independent.trace.nodes["__cause____consequent_Z"]["fn"].log_factor.shape

IndexSet({'X': {0}, 'Y': {0}})
IndexSet({'X': {0, 1, 2}})
IndexSet({})
IndexSet({'Y': {0, 1, 2}})
IndexSet({'Y': {0, 1, 2}})


torch.Size([3, 3, 1, 1, 1, 1])

In [33]:
def model_three_dependent():
    X = pyro.sample("X", dist.Bernoulli(0.5))
    Y = pyro.sample("Y", dist.Bernoulli(0.5))
    Z = pyro.sample("Z", dist.Bernoulli(torch.min(X,Y)))
    return {"X": X, "Y": Y, "Z": Z}

with ExtractSupports() as supports_independent:
    model_three_dependent()

with MultiWorldCounterfactual() as mwc_independent:  
        with SearchForExplanation(
            supports=supports_independent.supports,
            antecedents={"X": torch.tensor(1.0), "Y": torch.tensor(1.0)},
            consequents={"Z": torch.tensor(1.0)},
            witnesses={},
            alternatives={"X": torch.tensor(0.0), "Y": torch.tensor(0.0)},
            #antecedent_bias=-0.5,
            consequent_scale=0,
        ):
            with pyro.plate("sample", size=1):
                with pyro.poutine.trace() as trace_independent:
                    model_three_dependent()

trace_independent.trace.compute_log_prob

# with mwc_independent:
#     print(indices_of(trace_independent.trace.nodes["X"]["value"]))
#     print(indices_of(trace_independent.trace.nodes["Z"]["value"]))
#     print(indices_of(trace_independent.trace.nodes["Y"]["value"]))
#     print(indices_of(trace_independent.trace.nodes["__cause____consequent_Y"]["value"]))

# trace_independent.trace.nodes["__cause____consequent_Y"]["fn"].log_factor.shape

KeyboardInterrupt: 

In [None]:
def model_independent():
        X = pyro.sample("X", dist.Bernoulli(0.5))
        Y = pyro.sample("Y", dist.Bernoulli(0.5))

with ExtractSupports() as supports_independent:
    model_independent()

with MultiWorldCounterfactual() as mwc_independent:  
        with SearchForExplanation(
            supports=supports_independent.supports,
            antecedents={"X": torch.tensor(1.0)},
            consequents={"Y": torch.tensor(1.0)},
            witnesses={},
            alternatives={"X": torch.tensor(0.0)},
            antecedent_bias=-0.5,
            consequent_scale=0,
        ):
            with pyro.plate("sample", size=3):
                with pyro.poutine.trace() as trace_independent:
                    model_independent()

In [3]:
def model_triple():
    X = pyro.sample("X", dist.Bernoulli(0.5))
    Y = pyro.sample("Y", dist.Bernoulli(X))
    Z = pyro.sample("Z", dist.Bernoulli(Y))

with ExtractSupports() as supports_triple:
    model_triple()

with MultiWorldCounterfactual() as mwc_triple:  
        with SearchForExplanation(
            supports=supports_triple.supports,
            antecedents={"Z": torch.tensor(1.0)},
            consequents={"X": torch.tensor(1.0)},
            witnesses={},
            alternatives={"Z": torch.tensor(0.0)},
            antecedent_bias=-0.5,
            consequent_scale=0,
        ):
            with pyro.plate("sample", size=3):
                with pyro.poutine.trace() as trace_triple:
                    model_triple()

trace_triple.trace.compute_log_prob


trace_triple.trace.nodes["X"]["value"]
trace_triple.trace.nodes["__cause____consequent_X"]["fn"].log_factor

tensor([[[[[0., 0., 0.]]]],



        [[[[-inf, 0., -inf]]]],



        [[[[0., -inf, 0.]]]]])

In [29]:
event_shape = (3,) #(3,)
plate_size = 4


factors = {
    "consequent": consequent_eq_neq(
        support=constraints.independent(constraints.real, 0),
        proposed_consequent=torch.Tensor([0.01]), 
        antecedents=["w"],
    )
}


# fake_w =  dist.Normal(0, 0.1).expand(event_shape).to_event(len(event_shape)).sample()

# @Factors(factors=factors)
@pyro.plate("data", size=plate_size, dim=-4)
def model_ce():
        w = pyro.sample("w", dist.Normal(0, 0.1).expand(event_shape).to_event(len(event_shape)))
        # w = pyro.sample("w", dist.Normal(fake_w, 0.001))

        consequent = pyro.deterministic("consequent", w * torch.tensor(0.1))

        print("w", w.shape, "c", consequent.shape)
        assert w.shape == consequent.shape


antecedents = {
        "w": (
            torch.tensor(0.1).expand(event_shape),
            sufficiency_intervention(
                constraints.independent(constraints.real, len(event_shape)),
            ["w"]
            )
        )
    }

with MultiWorldCounterfactual() as mwc_ce:
    with do(actions = antecedents):
        with pyro.poutine.trace() as trace_ce: 
            model_ce()
    
print(trace_ce.trace.nodes.keys())
with mwc_ce:
    print(indices_of(trace_ce.trace.nodes["w"]["value"]))
    print(indices_of(trace_ce.trace.nodes["consequent"]["value"]))
    # print(indices_of(trace_ce.trace.nodes['__factor_consequent']["fn"].log_factor))
    # print(trace_ce.trace.nodes['__factor_consequent']["fn"].log_factor)


# nd = trace_ce.trace.nodes

# trace_ce.trace.compute_log_prob

# with mwc_ce:
#     eq_neq_log_probs_fact = gather(
#         nd["__factor_consequent"]["fn"].log_factor, IndexSet(**{"w": {0}},  event_dim = 0)
#     )

#     eq_neq_log_probs_nec = gather(
#         nd["__factor_consequent"]["fn"].log_factor, IndexSet(**{"w": {1}}, event_dim = 0)
#     )
    
#     consequent_suff = gather(
#         nd["consequent"]["value"], IndexSet(**{"w": {2}}, event_dim = 0 )
#     )


#     print("what's up", indices_of(nd["consequent"]["value"]))


#     eq_neq_log_probs_suff = gather(
#             nd["__factor_consequent"]["fn"].log_factor, IndexSet(**{"w": {2}})
#        )

#     assert eq_neq_log_probs_nec.shape == consequent_suff.shape

#     assert torch.equal(eq_neq_log_probs_suff, dist.Normal(0.0, .1).log_prob(consequent_suff - torch.tensor(.01)))
#     assert eq_neq_log_probs_nec.sum().exp() == 0           




w torch.Size([3, 4, 1, 1, 1, 3]) c torch.Size([3, 4, 1, 1, 1, 3])
odict_keys(['w', 'consequent'])
IndexSet({'w': {0, 1, 2, 3}})
IndexSet({'w': {0, 1, 2, 3}})


In [57]:
# event_shape = (3,) #(3,)
# plate_size = 4

# def test_consequent_eq_neq():
#     factors = {
#         "consequent": consequent_eq_neq(
#             support=constraints.independent(constraints.real, len(event_shape)),
#             proposed_consequent=torch.Tensor([0.01]), 
#             antecedents=["w"],
#             event_dim=len(event_shape),
#         )
#     }

#     @Factors(factors=factors)
#     @pyro.plate("data", size=plate_size, dim=-1)
#     def model_ce():
#         w = pyro.sample(
#             "w", dist.Normal(0, 0.1).expand(event_shape).to_event(len(event_shape))
#         )
#         #consequent = pyro.deterministic(
#         #    "consequent", w * 0.1, event_dim=len(event_shape)
#         #)
#         consequent = pyro.sample("consequent", dist.Delta(w * 0.1).to_event(len(event_shape)))

#         return consequent

#     antecedents = {
#         "w": (
#             torch.tensor(0.1).expand(event_shape),
#             sufficiency_intervention(
#                 constraints.independent(constraints.real, len(event_shape)), ["w"]
#             ),
#         )
#     }

#     with MultiWorldCounterfactual() as mwc:
#         with do(actions=antecedents):
#             with pyro.poutine.trace() as tr:
#                 model_ce()

#     tr.trace.compute_log_prob()
#     nd = tr.trace.nodes


#     with mwc:
#         eq_neq_log_probs_fact = gather(
#            nd["__factor_consequent"]["log_prob"], IndexSet(**{"w": {0}},  event_dim = 0)
#        )

#         eq_neq_log_probs_nec = gather(
#            nd["__factor_consequent"]["log_prob"], IndexSet(**{"w": {1}}, event_dim = 0)
#        )
        
#         consequent_suff = gather(
#             nd["consequent"]["value"], IndexSet(**{"w": {2}}, event_dim = 0 )
#         )


#         print(indices_of(nd["consequent"]["value"]))


#         eq_neq_log_probs_suff = gather(
#             nd["__factor_consequent"]["log_prob"], IndexSet(**{"w": {2}})
#        )

#     assert eq_neq_log_probs_nec.shape == consequent_suff.shape

#     assert torch.equal(eq_neq_log_probs_suff, dist.Normal(0.0, .1).log_prob(consequent_suff - torch.tensor(.01)))
#     assert eq_neq_log_probs_nec.sum().exp() == 0           
