In [1]:
import functools
import contextlib
import collections
from typing import Callable, Iterable, TypeVar, Mapping, List, Dict


from itertools import chain, combinations

import random

import pyro
import torch  

from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.indexed.ops import IndexSet, cond, gather, indices_of, scatter

S = TypeVar("S")
T = TypeVar("T")

import pyro
import chirho
import pyro.distributions as dist
import pyro.infer
import torch
import pandas as pd

from chirho.counterfactual.handlers.counterfactual import (MultiWorldCounterfactual,
        Preemptions)
from chirho.counterfactual.handlers.explanation import (
    SearchForCause,
    consequent_differs,
    random_intervention,
    undo_split,
    uniform_proposal,
    ExplainCauses
)
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.observational.handlers.condition import Factors, condition
from chirho.interventional.ops import Intervention, intervene
from chirho.interventional.handlers import do

In [2]:
# here is the handler for the causal explanation

def tensorize_dictionary(dictionary):
    return {k: torch.as_tensor(v) for k, v in dictionary.items()}

def boolean_constraints_from_list(list):
    return {k: pyro.distributions.constraints.boolean for k in list}


@contextlib.contextmanager
def Explanation_Evaluation( 
        consequents_observed: Dict[str, torch.Tensor],
        causal_candidates: List[str],
        condition_on_consequent = True,
        runs_n: int = 100,):

        consequents = list(consequents_observed.keys())
        consequents_observed = tensorize_dictionary(consequents_observed)
        # this needs to be replaced if nodes are not boolean
        causal_candidate_constraints = boolean_constraints_from_list(causal_candidates)

        # needed in order to check if always a part of the antecedent set is a part of an actual cause
        with MultiWorldCounterfactual() as mwc:
            with ExplainCauses(antecedents = causal_candidate_constraints, 
                      witnesses = causal_candidates, consequents = consequents,
                      antecedent_bias = .1,):
                if condition_on_consequent:
                    with condition(data = consequents_observed):
                        with pyro.plate("sample", runs_n):
                            with pyro.poutine.trace() as tr:
                                yield {"mwc": mwc, "tr" : tr}
                else:
                    with pyro.plate("sample", runs_n):
                        with pyro.poutine.trace() as tr:
                            yield {"mwc": mwc, "tr" : tr}

In [3]:
# two simple models

@pyro.infer.config_enumerate
def ff_disjunctive():
        u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
        u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

        match_dropped = pyro.deterministic("match_dropped",
                                        u_match_dropped, event_dim=0)
        lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
        forest_fire = pyro.deterministic("forest_fire", torch.max(match_dropped, lightning), event_dim=0)

        return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}
        
@pyro.infer.config_enumerate
def ff_conjunctive():
        u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
        u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

        match_dropped = pyro.deterministic("match_dropped",
                                        u_match_dropped, event_dim=0)
        lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
        forest_fire = pyro.deterministic("forest_fire", torch.min(match_dropped, lightning), event_dim=0)

        return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}

In [4]:
# in the conditoned model, I'd expect large positive logits
# If both conditions are required for the fire to occur

def guide():
    pass

ff_conjunctive_conditioned = pyro.condition(ff_conjunctive, data={"forest_fire": torch.tensor(1.)})
pyro.infer.TraceEnum_ELBO().compute_marginals(ff_conjunctive_conditioned, guide)

OrderedDict([('u_match_dropped', Bernoulli(logits: 0.0)),
             ('u_lightning', Bernoulli(logits: 0.0))])

In [5]:
# handlers for the two models

consequents_observed={"forest_fire": torch.tensor(True)}
causal_candidates=["match_dropped", "lightning"]
causal_candidates_dict = {"match_dropped": 1., "lightning": 1.}
antecedent_prefix = "__antecedent"

exp_ff_con_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates,
    condition_on_consequent = False,
    runs_n= 1000,
)

exp_ff_dis_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates,
    condition_on_consequent = False,
    runs_n= 1000,
)

with exp_ff_con_handler as con_ff:
       ff_conjunctive()

with exp_ff_dis_handler as dis_ff:
        ff_disjunctive()

In [7]:
# define a new function using the handler, make sure it works
# ignore using con_ff2 for now

def exp_ff_conjunctive():
    exp_ff_con_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates,
    condition_on_consequent = False,
    runs_n= 1,)
    
    with exp_ff_con_handler as con_ff2:
        y = ff_conjunctive()
        return y
    

y = exp_ff_conjunctive()   

print(y) 
#print(con_ff['tr'].trace.nodes['forest_fire']['value'])

{'match_dropped': tensor([[[[[1.]]]],



        [[[[1.]]]]]), 'lightning': tensor([[[[[[0.]]]]],




        [[[[[0.]]]]]]), 'forest_fire': tensor([[[[[[0.]]]],



         [[[[0.]]]]],




        [[[[[0.]]]],



         [[[[0.]]]]]])}


In [8]:
# now try to use TraceEnum_ELBO to compute the marginals

ff_exp_ff_conjunctive_conditioned = pyro.condition(exp_ff_conjunctive, data={"forest_fire": torch.tensor(1.)})
pyro.infer.TraceEnum_ELBO().compute_marginals(ff_exp_ff_conjunctive_conditioned, guide)

RuntimeError: shape mismatch: value tensor of shape [2, 1, 1, 1, 1, 1, 1, 1, 1] cannot be broadcast to indexing result of shape [2, 1, 1, 1, 1, 1, 1]
                            Trace Shapes:                    
                             Param Sites:                    
                            Sample Sites:                    
                              sample dist                   |
                                    value                 1 |
                     u_match_dropped dist                 1 |
                                    value     2 1 1 1 1 1 1 |
                         u_lightning dist                 1 |
                                    value   2 1 1 1 1 1 1 1 |
                       match_dropped dist     2 1 1 1 1 1 1 |
                                    value     2 1 1 1 1 1 1 |
__antecedent__proposal_match_dropped dist                 1 |
                                    value 2 1 1 1 1 1 1 1 1 |