In [1]:
import functools
import contextlib
import collections
from typing import Callable, Iterable, TypeVar, Mapping, List, Dict


from itertools import chain, combinations

import random

import pyro
import torch  

from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.indexed.ops import IndexSet, cond, gather, indices_of, scatter

S = TypeVar("S")
T = TypeVar("T")

import pyro
import chirho
import pyro.distributions as dist
import pyro.infer
import torch
import pandas as pd

from chirho.counterfactual.handlers.counterfactual import (MultiWorldCounterfactual,
        Preemptions)
from chirho.counterfactual.handlers.explanation import (
    SearchForCause,
    consequent_differs,
    random_intervention,
    undo_split,
    uniform_proposal,
    ExplainCauses
)
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.observational.handlers.condition import Factors, condition
from chirho.interventional.ops import Intervention, intervene
from chirho.interventional.handlers import do

In [2]:
def tensorize_dictionary(dictionary):
    return {k: torch.as_tensor(v) for k, v in dictionary.items()}

def boolean_constraints_from_list(list):
    return {k: pyro.distributions.constraints.boolean for k in list}


@contextlib.contextmanager
def Explanation_Evaluation( 
        consequents_observed: Dict[str, torch.Tensor],
        causal_candidates: List[str],
        condition_on_consequent = True,
        runs_n: int = 100,):

        consequents = list(consequents_observed.keys())
        consequents_observed = tensorize_dictionary(consequents_observed)
        # this needs to be replaced if nodes are not boolean
        causal_candidate_constraints = boolean_constraints_from_list(causal_candidates)

        # needed in order to check if always a part of the antecedent set is a part of an actual cause
        with MultiWorldCounterfactual() as mwc:
            with ExplainCauses(antecedents = causal_candidate_constraints, 
                      witnesses = causal_candidates, consequents = consequents,
                      antecedent_bias = .1,):
                if condition_on_consequent:
                    with condition(data = consequents_observed):
                        with pyro.plate("sample", runs_n):
                            with pyro.poutine.trace() as tr:
                                yield {"mwc": mwc, "tr" : tr}
                else:
                    with pyro.plate("sample", runs_n):
                        with pyro.poutine.trace() as tr:
                            yield {"mwc": mwc, "tr" : tr}

## Implementing the original definition

The trace obtained using the above handler already contains information that can be used to implement the original definition. Such a query can be answered with only two core concepts: subset inclusion and log prob sum comparison. First, let's implement this definition. Later on, we'll move beyond some of the idiosyncracies involved here.

In [3]:
# getting a table from the trace
# post sampling conditioning on observations or interventions

def gather_observed(value, causal_candidates):
            
    _indices = [i for i in causal_candidates if i in indices_of(value, event_dim=0)]
    _int_can = gather(value, IndexSet(**{i: {0} for i in _indices}), event_dim=0,)
    return _int_can


def gather_intervened(value, causal_candidates):
        
    _indices = [i for i in causal_candidates if i in indices_of(value, event_dim=0)]
    _int_can = gather(value, IndexSet(**{i: {1} for i in _indices}), event_dim=0,)
    return _int_can


def get_explanation_table(trace, mwc, causal_candidates, consequents):

    trace.trace.compute_log_prob()
    
    table_dict = {}
    nodes = trace.trace.nodes

    for candidate in causal_candidates:
        with mwc:
            table_dict[f"obs_{candidate}"] = gather_observed(nodes[candidate]['value'],
                                                             causal_candidates).squeeze().tolist()
    
            
        table_dict[f"pint_{candidate}"] = nodes[f'__antecedent__proposal_{candidate}']['value'].squeeze().tolist()
        table_dict[f"apre_{candidate}"] = nodes[f'__antecedent_{candidate}']['value'].squeeze().tolist()
        table_dict[f"lp_apre_{candidate}"] = nodes[f"__antecedent_{candidate}"]['log_prob']

        table_dict[f"wpre_{candidate}"] = nodes[f'__witness_{candidate}']['value'].squeeze().tolist()

        # context used twice to make the table more legible (meaningful column ordering)
        with mwc:
            table_dict[f"int_{candidate}"] = gather_intervened(nodes[candidate]['value'],
                                                             causal_candidates).squeeze().tolist()
            

    for consequent in consequents:
        with mwc:
            table_dict[f"obs_{consequent}"] = gather_observed(nodes[consequent]["value"], causal_candidates).squeeze().tolist()
            table_dict[f"int_{consequent}"] = gather_intervened(nodes[consequent]["value"], causal_candidates).squeeze().tolist()    
            table_dict[f"lp_{consequent}"] = gather_intervened(nodes[f"__consequent_{consequent}"]["log_prob"], causal_candidates).squeeze().tolist()
        
            # # a slightly hacky way to avoid shape errors in the extended forest fire example
            # # where the index of ff is empty
            # if len(table_dict[f"obs_{consequent}"]) == 2:
            #     table_dict[f"obs_ff"] = table_dict[f"obs_{consequent}"][0]
            #     table_dict[f"int_ff"] = table_dict[f"obs_{consequent}"][1]
            #     table_dict[f"lp_{consequent}"] = table_dict[f"lp_{consequent}"][1]
        
    table_pd = pd.DataFrame(table_dict).drop_duplicates()

    # remove rows where proposed interventions are the same as the observed values
    for candidate in causal_candidates:
        mask = table_pd[f'obs_{candidate}'] != table_pd[f'pint_{candidate}']
        table_pd = table_pd[mask]
      
    
    summands = [col for col in table_pd.columns if col.startswith('lp')]
    table_pd["sum_lp"] =  table_pd[summands].sum(axis = 1)
    table_pd.sort_values(by = "sum_lp", ascending = False, inplace = True)

    # some sanity checks
    for candidate in causal_candidates:
        
        # witness preempted nodes have the same observed and intervened values
        assert all(table_pd.loc[table_pd[f'wpre_{candidate}'] == 1, 
        f'obs_{candidate}'] == table_pd.loc[table_pd[f'wpre_{candidate}'] == 1,
                                            f'int_{candidate}'])

        # antecedent preemptions leave obs and int values unchanged
        # this might fail generally if a node is downstream from some 
        # other intervention, but let's keep this sanity check for the simple models 
        # for now        
        assert all(table_pd.loc[table_pd[f'apre_{candidate}'] == 1,
                                f'obs_{candidate}'] == 
                   table_pd.loc[table_pd[f'apre_{candidate}'] == 1,
                                f'int_{candidate}'])

    return table_pd


# brute force conditioning on the observed causal candidates
# needed as adding them to condition statements leads to shape errors
# needed to implement the original def, not needed later on

# post-sampling conditioning on observations
def get_conditioned_table(table, dict_of_nodes):
    conditioned_table = table.copy()
    for candidate in dict_of_nodes.keys():
        obs_column = f'obs_{candidate}'
        mask = (conditioned_table[obs_column] == dict_of_nodes[candidate])
        conditioned_table = conditioned_table[mask]
        conditioned_table.reset_index(drop=True, inplace=True)

    return conditioned_table

# post-sampling conditioning on selected variables being intervened
def get_intervened_table(table, dict_of_nodes):
    intervened_table = table.copy()
    for candidate in dict_of_nodes.keys():
        condition_string = f"`apre_{candidate}` == 0 and `wpre_{candidate}` == 0 and int_{candidate} == {dict_of_nodes[candidate]}"        
        intervened_table = intervened_table.query(condition_string)
        intervened_table.reset_index(drop=True, inplace=True)

    return intervened_table
    

# a few other utility functions
def powerset(dct):
    keys = list(dct.keys())
    key_tuples = list(chain.from_iterable(combinations(keys, r) for r in range(len(keys) + 1)))[:-1]
    subdicts = [{k: dct[k] for k in tpl} for tpl in key_tuples]
    return subdicts

def active(node, row):
    return row[f"apre_{node}"] == 0 and row[f"wpre_{node}"] == 0 and row[f"obs_{node}"] != row[f"pint_{node}"]

def minimal_sets(set_list):
    inclusion_minimal = []
    for s1 in set_list:
        is_minimal = True
        for s2 in set_list:
            if s1 != s2 and s2.issubset(s1):
                is_minimal = False
        if is_minimal:
            inclusion_minimal.append(s1)
    return inclusion_minimal


def minimal_cause_sets(table, causal_candidates_dict):
    
    subset_minimal_causes = []
    # TODO remove if really redundant
    subset_logprobs = []
    changers = table[table['sum_lp']  > -1e7].copy()
    
    for _, row in changers.iterrows():
        active_set = {node for node in causal_candidates_dict.keys() if active(node, row)}
        if active_set not in subset_minimal_causes:
            subset_minimal_causes.append(active_set)
    # TODO remove if really redudnant
    #        subset_logprobs.append(row['sum_lp'])
        subset_minimal_causes = minimal_sets(subset_minimal_causes)
        
    return subset_minimal_causes


In [4]:
class ExplanationHalpern():

    def __init__(self, table, consequent_dict, causal_candidates_dict):
        self.table = table
        self.consequent_dict = consequent_dict
        self.causal_candidates_dict = causal_candidates_dict

        self.ex1A = self.ex1A_check(self.table, consequent_dict= self.consequent_dict, causal_candidates_dict= self.causal_candidates_dict)
        self.ex1B = self.ex1B_check(self.table, consequent_dict=self.consequent_dict, causal_candidates_dict= self.causal_candidates_dict)
        self.ex2 = self.ex2_check(self.table, consequent_dict= self.consequent_dict, causal_candidates_dict= self.causal_candidates_dict)
        self.ex3_4 = self.ex3_4_check(self.table, consequent_dict=self.consequent_dict, causal_candidates_dict=self.causal_candidates_dict)
        self.explanation = self.explanation_check()
        print(f"explanation check: {self.explanation}, ", 
              f"causal_overlap: {self.ex1B}, ",
               f"minimality: {self.ex2}, ",
                f"possibility: {self.ex3_4[0]}, ", 
                f"non-triviality: {self.ex3_4[1]}")
    
    # Condition 1 requires that (A) in all contexts in which the causal candidates have the postulated values 
    # and the consequent is changed, there is at least one member of the causal candidates 
    # that is a member of a minimal actual cause of that change to the consequent.
    def ex1A_check(self, table, consequent_dict, causal_candidates_dict):

        # suppose not only consequent but also the causal candidates are observed    
        merged_dict = {**consequent_dict, **causal_candidates_dict}
        
        table_conditioned = get_conditioned_table(table, merged_dict)

        # check if causal candidates coincide with a minimal cause set
        minimal_conditioned = minimal_cause_sets(table_conditioned, causal_candidates_dict)
        parthood_flag = any(any(candidate in s for candidate in causal_candidates_dict.keys()) for s in minimal_conditioned)   
        return parthood_flag, minimal_conditioned
    
    # Moreover, it requires that (B) for any context under consideration, intervening on all causal candidates 
    # to have the postulated value leads to the consequent.
    def ex1B_check(self, table, causal_candidates_dict, consequent_dict):
        table_intervened = get_intervened_table(table, causal_candidates_dict)
        return all((table_intervened[f"int_{key}"] == consequent_dict[key]).all() for key in consequent_dict.keys()) 

    # The second condition of Halpern's definition is that the causal candidate set 
    # be a subset-minimal one satisfying both `ex1A` and `ex1B`.
    def ex2_check(self, table,  consequent_dict, causal_candidates_dict,):
        sub_candidates = powerset(causal_candidates_dict)
        minimality_flag = True
        for sub_candidate in sub_candidates:
            ex1A_flag = self.ex1A_check(table, consequent_dict= consequent_dict, causal_candidates_dict= sub_candidate)[0]
            ex1B_flag = self.ex1B_check(table, consequent_dict = consequent_dict, causal_candidates_dict= sub_candidate)
            
            if ex1A_flag and ex1B_flag:
                minimality_flag = False
                break
        return minimality_flag
    
    # The remaining checks are whether the explanation is possible (given the consequent) and non-trivial.
    def ex3_4_check(self, table, consequent_dict, causal_candidates_dict):

        merged_dict = {**consequent_dict, **causal_candidates_dict}
        table_conditioned = get_conditioned_table(table, merged_dict)
        ex3 = table_conditioned.shape[0] > 0
        ex4 = any(any(table[f"obs_{key}"] != causal_candidates_dict[key]) for key in causal_candidates_dict.keys())

        return ex3, ex4
    
    def explanation_check(self):
        return self.ex1A[0] and self.ex1B and self.ex2 and self.ex3_4[0] 

# Examples

## Forest fire example

Example 7.1.2. from Halpern's *Actual Causality*. In the conjunctive model if both a lightning occurs and a match is dropped, a forest fire results, but both factors are required. In the disjunctive model, each of these factors individually is sufficient for the forest fire.

Suppose all contexts are available and no setting is excluded by what the agent knows about the world. 
- In the conjunctive model, the joint nodes are an explanation of forest fire, none of the individual ones is. This is in contrast with actual causality claims, as each of the nodes is an actual cause, but the conjunction is not.
- In the disjunctive model, the reverse is true. Each node is an explanation of forest fire, but the conjunction is not. This is in contrast with actual causality claims, as none of the individual nodes is an actual cause, but the conjunction is.

In [5]:
@pyro.infer.config_enumerate
def ff_conjunctive():
    u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped",
                                        u_match_dropped, event_dim=0)
    lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
    forest_fire = pyro.deterministic("forest_fire", (match_dropped.bool() & lightning.bool()),
                                      event_dim=0)

@pyro.infer.config_enumerate
def ff_disjunctive():
        u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
        u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

        match_dropped = pyro.deterministic("match_dropped",
                                        u_match_dropped, event_dim=0)
        lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
        forest_fire = pyro.deterministic("forest_fire", (match_dropped.bool() | lightning.bool()).bool(), event_dim=0)

        return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}


In [6]:
consequents_observed={"forest_fire": torch.tensor(True)}
causal_candidates=["match_dropped", "lightning"]
causal_candidates_dict = {"match_dropped": 1., "lightning": 1.}
antecedent_prefix = "__antecedent"

exp_ff_con_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates,
    condition_on_consequent = False,
    runs_n= 1000,
)

exp_ff_dis_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates,
    condition_on_consequent = False,
    runs_n= 1000,
)

with exp_ff_con_handler as con_ff:
        ff_conjunctive()

with exp_ff_dis_handler as dis_ff:
        ff_disjunctive()

table_con_ff = get_explanation_table(con_ff["tr"], con_ff["mwc"], causal_candidates, consequents=["forest_fire"])
table_dis_ff = get_explanation_table(dis_ff["tr"], dis_ff["mwc"], causal_candidates, consequents=["forest_fire"])

In [7]:
# TraceENUM_ELBO works with bare model, doesn't seem to work 
# with Explanation_Evaluation

# def exp_ff_conjunctive():
#     exp_ff_con_handler =  Explanation_Evaluation(
#     consequents_observed=consequents_observed,
#     causal_candidates= causal_candidates,
#     runs_n= 1000,
#     )
#     with exp_ff_con_handler as con_ff:
#         ff_conjunctive()

# def guide():
#     pass

# works with bare model
#ff_conjunctive_conditioned = pyro.condition(ff_conjunctive, data={"forest_fire": torch.tensor(True)})
#pyro.infer.TraceEnum_ELBO().compute_marginals(ff_conjunctive_conditioned, guide)

#doesn't work here
#pyro.infer.TraceEnum_ELBO().compute_marginals(exp_ff_conjunctive,guide)

# will proceed with sampling from the model


In [8]:
# conjunction in the conjunctive model
con_ff_explanation = ExplanationHalpern(table_con_ff, {"forest_fire": True}, causal_candidates_dict)
# one of the nodes in the conjunctive model
con_ff_explanation_match = ExplanationHalpern(table_con_ff, {"forest_fire": True}, {"match_dropped": 1.})

# conjunction in the disjunctive model
dis_ff_explanation = ExplanationHalpern(table_dis_ff, {"forest_fire": True}, causal_candidates_dict)
# one of the nodes in the disjunctive model
dis_ff_explanation_match = ExplanationHalpern(table_dis_ff, {"forest_fire": True}, {"match_dropped": 1.})

explanation check: True,  causal_overlap: True,  minimality: True,  possibility: True,  non-triviality: True
explanation check: False,  causal_overlap: False,  minimality: True,  possibility: True,  non-triviality: True
explanation check: False,  causal_overlap: True,  minimality: False,  possibility: True,  non-triviality: True
explanation check: True,  causal_overlap: True,  minimality: True,  possibility: True,  non-triviality: True


### Extended forest fire

This is based on example 7.1.4. In April, given the electrical storm in May, the forest would have caught fire in May (and not in June). However, given the storm, if there had been an electrical storm only in May, the forest
would not have caught fire at all; if there had been an electrical storm only in June, it would have caught fire in June. The model has five endogenous variables: `ar` for *April rains*, 
`esm` for *electric storms in May*, `esj` for *electric storms in June*, `ffm` for *forest fire in May*, `ffj` for *forest fire in June* and `ff` for *forest fire either in May or in June (or both)*. 

In [9]:
def ff_extended():
    u_ar = pyro.sample("u_ar", dist.Bernoulli(0.5))
    u_esm = pyro.sample("u_esm", dist.Bernoulli(0.5))
    u_esj = pyro.sample("u_esj", dist.Bernoulli(0.5))

    ar = pyro.deterministic("ar", u_ar, event_dim=0)
    esm = pyro.deterministic("esm", u_esm, event_dim=0)
    esj = pyro.deterministic("esj", u_esj, event_dim=0)

    ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  
    ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).squeeze().float()
   
    ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
    
    return {"u_ar": u_ar, "u_esm": u_esm, "u_esj": u_esj, 
            "ar": ar, "esm": esm, "esj": esj, "ffm": ffm, "ffj": ffj, "ff": ff}


In [23]:
# with no excluded settings
# one explanation for ff is esj


consequents_observed={"ff": 1.}
#causal_candidates_1= ["ar", "esm"]
causal_candidates_dict_1 = {"esm": 1.,  "ar": 0.}
causal_candidates_1= list(causal_candidates_dict_1.keys())


exp_ff_ext1_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates_1,
    condition_on_consequent = False,
    runs_n= 1000,
)

with exp_ff_ext1_handler as ext_ff_1:
    ff_extended()


nodes = ext_ff_1['tr'].trace.nodes
mwc = ext_ff_1['mwc']

print(nodes.keys())

with mwc:
    value = nodes['ff']['value']
    print(value.shape)
    _indices = [i for i in causal_candidates_1 if i in indices_of(value, event_dim=0)]
    _int_can = gather(value, IndexSet(**{i: {1} for i in _indices}), event_dim=0,)
    print(indices_of(value, event_dim=0))
    print(_int_can.shape)
    #print(gather_observed(nodes['ff']['value'], causal_candidates_1).shape)

odict_keys(['u_ar', 'u_esm', 'u_esj', '__antecedent__proposal_ar', '__antecedent_ar', '__witness_ar', 'ar', '__antecedent__proposal_esm', '__antecedent_esm', '__witness_esm', 'esm', 'esj', 'ffm', 'ffj', '__consequent_ff', 'ff'])
torch.Size([2, 2, 1, 2, 2, 1000])
IndexSet({'ar': {0, 1}, 'esm': {0, 1}})
torch.Size([1, 1, 1, 2, 2, 1000])


  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).squeeze().float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()


In [None]:


table_ext_ff_1 = get_explanation_table(ext_ff_1["tr"], ext_ff_1["mwc"], causal_candidates_1, consequents=["ff"])

display(table_ext_ff_1)

In [None]:

ext_ff_1_explanation = ExplanationHalpern(table_ext_ff_1, {"ff": 1.}, causal_candidates_dict_1)
#

In [None]:


exp_ff_dis_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates,
    condition_on_consequent = False,
    runs_n= 1000,
)

with exp_ff_con_handler as con_ff:
        ff_conjunctive()


exp_ff_extended_handler =  Explanation_Evaluation(
    model = ff_extended,
    antecedents={"esj": 1.},
    consequents_observed={"ff": 1.},
    endogenous_nodes=["ar", "esm", "esj"],
) 

with exp_ff_extended_handler as ext_ff1:
    ff_extended()

explanation_check(ext_ff1)

In [19]:
# another is esm=1 and ar=0

exp_ff_extended_handler2 =  Explanation_Evaluation(
    model = ff_extended,
    antecedents= {"esm": 1., "ar": 0.},
    consequents_observed={"ff": 1.},
    endogenous_nodes=["ar", "esm", "esj"],
) 

with exp_ff_extended_handler2 as ext_ff2:
    ff_extended()

explanation_check(ext_ff2)

  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()


True

In [20]:
# explore all 27 possible explanations
# to confirm that these are the only 
# possible explanations

explanatory_status = []
candidates = []

nodes = ["ar", "esm", "esj"]
subsets = [[]]
for node in nodes:
        subsets.extend([subset + [node] for subset in subsets])


for subset in subsets:
        for _ in range(50):
                random_setting = [random.choice([0., 1.]) 
                          for _ in range(len(subset))]
                candidate = {var: val for var, val in zip(subset, random_setting)}

                if candidate in candidates:
                        continue

                candidates.append(candidate)

                _ext_handler =  Explanation_Evaluation(
                                model = ff_extended,
                                antecedents= candidate,
                                consequents_observed={"ff": 1.},
                                endogenous_nodes=["ar", "esm", "esj"],
                                ) 
                
                with _ext_handler as _ext_obj:
                        ff_extended()

                explanatory_status.append(explanation_check(_ext_obj))

explanation_search = pd.DataFrame({"candidates": candidates,
                                   "explanatory_status": explanatory_status})

print(explanation_search)



                

  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logi

  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffm = pyro.deterministic(

                             candidates  explanatory_status
0                                    {}               False
1                           {'ar': 0.0}               False
2                           {'ar': 1.0}               False
3                          {'esm': 1.0}               False
4                          {'esm': 0.0}               False
5               {'ar': 1.0, 'esm': 1.0}               False
6               {'ar': 0.0, 'esm': 1.0}                True
7               {'ar': 1.0, 'esm': 0.0}               False
8               {'ar': 0.0, 'esm': 0.0}               False
9                          {'esj': 1.0}                True
10                         {'esj': 0.0}               False
11              {'ar': 1.0, 'esj': 1.0}               False
12              {'ar': 0.0, 'esj': 0.0}               False
13              {'ar': 0.0, 'esj': 1.0}               False
14              {'ar': 1.0, 'esj': 0.0}               False
15             {'esm': 0.0, 'esj': 1.0} 

  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
