In [1]:
import functools
import contextlib
import collections
from typing import Callable, Iterable, TypeVar, Mapping, List, Dict


from itertools import chain, combinations

import random

import pyro
import torch  # noqa: F401

from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.indexed.ops import IndexSet, cond, gather, indices_of, scatter

S = TypeVar("S")
T = TypeVar("T")

import pyro
import chirho
import pyro.distributions as dist
import pyro.infer
import torch
import pandas as pd

from chirho.counterfactual.handlers.counterfactual import (MultiWorldCounterfactual,
        Preemptions)
from chirho.counterfactual.handlers.explanation import (
    SearchForCause,
    consequent_differs,
    random_intervention,
    undo_split,
    uniform_proposal,
)
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.observational.handlers.condition import Factors, condition
from chirho.interventional.ops import Intervention, intervene
from chirho.interventional.handlers import do

In [2]:
@contextlib.contextmanager
def ExplainCauses(
    antecedents: Mapping[str, Intervention[T]]
    | Mapping[str, pyro.distributions.constraints.Constraint],
    witnesses: Mapping[str, Intervention[T]] | Iterable[str],
    consequents: Mapping[str, Callable[[T], float | torch.Tensor]]
    | Iterable[str],
    *,
    antecedent_bias: float = 0.0,
    witness_bias: float = 0.0,
    consequent_eps: float = -1e8,
    antecedent_prefix: str = "__antecedent_",
    witness_prefix: str = "__witness_",
    consequent_prefix: str = "__consequent_",
):
    """
    Effect handler for causal explanation.

    :param antecedents: A mapping from antecedent names to interventions.
    :param witnesses: A mapping from witness names to interventions.
    :param consequents: A mapping from consequent names to factor functions.
    """
    if isinstance(
        next(iter(antecedents.values())),
        pyro.distributions.constraints.Constraint,
    ):
        antecedents = {
            a: random_intervention(s, name=f"{antecedent_prefix}_proposal_{a}")
            for a, s in antecedents.items()
        }

    if not isinstance(witnesses, collections.abc.Mapping):
        witnesses = {
            w: undo_split(antecedents=list(antecedents.keys()))
            for w in witnesses
        }

    if not isinstance(consequents, collections.abc.Mapping):
        consequents = {
            c: consequent_differs(
                antecedents=list(antecedents.keys()), eps=consequent_eps
            )
            for c in consequents
        }

    if len(consequents) == 0:
        raise ValueError("must have at least one consequent")

    if len(antecedents) == 0:
        raise ValueError("must have at least one antecedent")

    if set(consequents.keys()) & set(antecedents.keys()):
        raise ValueError(
            "consequents and possible antecedents must be disjoint"
        )

    if set(consequents.keys()) & set(witnesses.keys()):
        raise ValueError("consequents and possible witnesses must be disjoint")

    antecedent_handler = SearchForCause(
        actions=antecedents, bias=antecedent_bias, prefix=antecedent_prefix
    )
    witness_handler = Preemptions(
        actions=witnesses, bias=witness_bias, prefix=witness_prefix
    )
    consequent_handler = Factors(factors=consequents, prefix=consequent_prefix)

    with antecedent_handler, witness_handler, consequent_handler:
            yield

In [3]:

def tensorize_dictionary(dictionary):
    return {k: torch.as_tensor(v) for k, v in dictionary.items()}

def boolean_constraints_from_list(list):
    return {k: pyro.distributions.constraints.boolean for k in list}

#pyro.distributions.constraints.independent(support), len(event_shape))

# trace-handling helper functions

def gather_observed(value, causal_candidates):
            
    _indices = [
            i for i in causal_candidates if i in indices_of(value, event_dim=0)
        ]
    _int_can = gather(
    value, IndexSet(**{i: {0} for i in _indices}), event_dim=0,)
    return _int_can


def gather_intervened(value, causal_candidates):
    
        
    _indices = [
            i for i in causal_candidates if i in indices_of(value, event_dim=0)
        ]
    _int_can = gather(
    value, IndexSet(**{i: {1} for i in _indices}), event_dim=0,)
    return _int_can




def get_table(trace, mwc, antecedents, witnesses, consequents):

    values_table = {}
    nodes = trace.trace.nodes
    

    if isinstance(antecedents, dict):
        antecedents_list = list(antecedents.keys())
    else:
        antecedents_list = antecedents

    with mwc:

        for antecedent_str in antecedents_list:
                
            obs_ant = gather_observed(nodes[antecedent_str]["value"], antecedents_list, witnesses)
            int_ant = gather_intervened(nodes[antecedent_str]["value"], antecedents_list, witnesses)

            values_table[f"{antecedent_str}_obs"] = obs_ant.squeeze().tolist()
            values_table[f"{antecedent_str}_int"] = int_ant.squeeze().tolist()
            
            apr_ant = nodes[f"__antecedent_{antecedent_str}"]["value"]
            values_table[f"apr_{antecedent_str}"] = apr_ant.squeeze().tolist()
            
            values_table[f"apr_{antecedent_str}_lp"] = nodes[f"__antecedent_{antecedent_str}"]["fn"].log_prob(apr_ant)

        if witnesses:
            for candidate in witnesses:
                obs_candidate = gather_observed(nodes[candidate]["value"], antecedents_list, witnesses)
                int_candidate = gather_intervened(nodes[candidate]["value"], antecedents_list, witnesses)
                values_table[f"{candidate}_obs"] = obs_candidate.squeeze().tolist()
                values_table[f"{candidate}_int"] = int_candidate.squeeze().tolist()

                wpr_con = nodes[f"__witness_{candidate}"]["value"]
                values_table[f"wpr_{candidate}"] = wpr_con.squeeze().tolist()
            

        for consequent in consequents:
            
            obs_consequent = gather_observed(nodes[consequent]["value"], antecedents_list, witnesses)
            int_consequent = gather_intervened(nodes[consequent]["value"], antecedents_list, witnesses)
            con_lp = nodes[f"__consequent_{consequent}"]['fn'].log_prob(torch.tensor(1)) #TODO: this feels like a hack
            _indices_lp = [
            i for i in antecedents_list + witnesses if i in indices_of(con_lp)]
            int_con_lp = gather(con_lp, IndexSet(**{i: {1} for i in _indices_lp}), event_dim=0,)      


            values_table[f"{consequent}_obs"] = obs_consequent.squeeze().tolist()   
            values_table[f"{consequent}_int"] = int_consequent.squeeze().tolist()
            values_table[f"{consequent}_lp"] = int_con_lp.squeeze().tolist()   

    values_df = pd.DataFrame(values_table)

    values_df.drop_duplicates(inplace=True)

    summands = [col for col in values_df.columns if col.endswith('lp')]
    values_df["sum_log_prob"] =  values_df[summands].sum(axis = 1)
    values_df.sort_values(by = "sum_log_prob", ascending = False, inplace = True)

    return values_df
    

In [4]:
def get_explanation_table(trace, mwc, causal_candidates, consequents):

    trace.trace.compute_log_prob()
    
    table_dict = {}
    nodes = trace.trace.nodes

    for candidate in causal_candidates:
        with mwc:
            table_dict[f"obs_{candidate}"] = gather_observed(nodes[candidate]['value'],
                                                             causal_candidates).squeeze().tolist()
    
            
        table_dict[f"pint_{candidate}"] = nodes[f'__antecedent__proposal_{candidate}']['value'].squeeze().tolist()
        table_dict[f"apre_{candidate}"] = nodes[f'__antecedent_{candidate}']['value'].squeeze().tolist()
        table_dict[f"lp_apre_{candidate}"] = nodes[f"__antecedent_{candidate}"]['log_prob']

        table_dict[f"wpre_{candidate}"] = nodes[f'__witness_{candidate}']['value'].squeeze().tolist()

        # context used twice to make the table more legible (meaningful column ordering)
        with mwc:
            table_dict[f"int_{candidate}"] = gather_intervened(nodes[candidate]['value'],
                                                             causal_candidates).squeeze().tolist()
            

    for consequent in consequents:
        with mwc:
            table_dict[f"obs_{consequent}"] = gather_observed(nodes[consequent]["value"], causal_candidates).squeeze().tolist()
            table_dict[f"int_{consequent}"] = gather_intervened(nodes[consequent]["value"], causal_candidates).squeeze().tolist()
            table_dict[f"lp_{consequent}"] = gather_intervened(nodes[f"__consequent_{consequent}"]["log_prob"], causal_candidates).squeeze().tolist()
        


    table_pd = pd.DataFrame(table_dict).drop_duplicates()

    # remove rows where proposed interventions are the same as the observed values
    for candidate in causal_candidates:
        mask = table_pd[f'obs_{candidate}'] != table_pd[f'pint_{candidate}']
        table_pd = table_pd[mask]
      
    
    summands = [col for col in table_pd.columns if col.startswith('lp')]
    table_pd["sum_lp"] =  table_pd[summands].sum(axis = 1)
    table_pd.sort_values(by = "sum_lp", ascending = False, inplace = True)

    # some sanity checks
    for candidate in causal_candidates:
        
        # witness preempted nodes have the same observed and intervened values
        assert all(table_pd.loc[table_pd[f'wpre_{candidate}'] == 1, 
        f'obs_{candidate}'] == table_pd.loc[table_pd[f'wpre_{candidate}'] == 1,
                                            f'int_{candidate}'])

        # antecedent preemptions leave obs and int values unchanged
        # this might fail generally if a node is downstream from some 
        # other intervention, but let's keep this sanity check for the simple models 
        # for now        
        assert all(table_pd.loc[table_pd[f'apre_{candidate}'] == 1,
                                f'obs_{candidate}'] == 
                   table_pd.loc[table_pd[f'apre_{candidate}'] == 1,
                                f'int_{candidate}'])


    
    return table_pd
    

In [5]:
@contextlib.contextmanager
def Explanation_Evaluation( 
        consequents_observed: Dict[str, torch.Tensor],
        causal_candidates: List[str],
        runs_n: int = 100,):

        consequents = list(consequents_observed.keys())
        consequents_observed = tensorize_dictionary(consequents_observed)
        # this needs to be replaced if nodes are not boolean
        causal_candidate_constraints = boolean_constraints_from_list(causal_candidates)

        # needed to check if always a part of the antecedent set is a part of an actual cause
        with MultiWorldCounterfactual() as mwc:
            with ExplainCauses(antecedents = causal_candidate_constraints, 
                      witnesses = causal_candidates, consequents = consequents,
                      antecedent_bias = .1):
                with condition(data = {"forest_fire": torch.tensor(True)}):
                    with pyro.plate("sample", runs_n):
                        with pyro.poutine.trace() as tr:
                            yield {"mwc": mwc, "tr" : tr}

In [6]:
@pyro.infer.config_enumerate
def ff_conjunctive():
    u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped",
                                        u_match_dropped, event_dim=0)
    lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
    forest_fire = pyro.deterministic("forest_fire", (match_dropped.bool() & lightning.bool()),
                                      event_dim=0)

@pyro.infer.config_enumerate
def ff_disjunctive():
        u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
        u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

        match_dropped = pyro.deterministic("match_dropped",
                                        u_match_dropped, event_dim=0)
        lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
        forest_fire = pyro.deterministic("forest_fire", (match_dropped.bool() | lightning.bool()).bool(), event_dim=0)

        return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}


In [7]:
consequents_observed={"forest_fire": torch.tensor(True)}
causal_candidates=["match_dropped", "lightning"]
antecedent_prefix = "__antecedent"

exp_ff_con_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates,
    runs_n= 1000,
)

def exp_ff_conjunctive(consequents_observed, causal_candidates, runs_n=1000):
    exp_ff_con_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates,
    runs_n= 1000,
    )
    with exp_ff_con_handler as con_ff:
        ff_conjunctive()

def guide():
    pass

# works with bare model
ff_conjunctive_conditioned = pyro.condition(ff_conjunctive, data={"forest_fire": torch.tensor(True)})
pyro.infer.TraceEnum_ELBO().compute_marginals(ff_conjunctive_conditioned,guide)

#doesn't work here
#pyro.infer.TraceEnum_ELBO().compute_marginals(exp_ff_conjunctive,guide)


OrderedDict([('u_match_dropped', Bernoulli(logits: 0.0)),
             ('u_lightning', Bernoulli(logits: 0.0))])

In [8]:
#this seems hacky

adam = pyro.optim.Adam({'lr': 0.03})
elbo = pyro.infer.Trace_ELBO()

def guide_ff_con(consequents_observed, causal_candidates, runs_n=1000):
    pass

svi = pyro.infer.SVI(exp_ff_conjunctive, guide_ff_con, adam, loss=elbo)

num_iterations = 100

for j in range(10):
    loss = svi.step(consequents_observed, causal_candidates, runs_n = 1)
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss))
        

[iteration 0001] loss: 178200024451.3400




In [9]:
#and doesn't seem to work anyway
#predictive = pyro.infer.Predictive(exp_ff_conjunctive,guide_ff_con, num_samples= 100)

In [10]:
# let's just run without inference 
# and for now do it brute force

with exp_ff_con_handler as con_ff:
        ff_conjunctive()

In [13]:
mwc = con_ff["mwc"]
tr = con_ff["tr"]

table = get_explanation_table(tr, mwc, causal_candidates, consequents=["forest_fire"])


# hacky brute force conditioning
conditioned_table = table.copy()
for candidate in causal_candidates:
    obs_column = f'obs_{candidate}'
    mask = (table[obs_column] == 1.0)
    conditioned_table = conditioned_table[mask]

display(conditioned_table)


  conditioned_table = conditioned_table[mask]


Unnamed: 0,obs_match_dropped,pint_match_dropped,apre_match_dropped,lp_apre_match_dropped,wpre_match_dropped,int_match_dropped,obs_lightning,pint_lightning,apre_lightning,lp_apre_lightning,wpre_lightning,int_lightning,obs_forest_fire,int_forest_fire,lp_forest_fire,sum_lp
5,1.0,0.0,1,-0.510826,0,1.0,1.0,0.0,0,-0.916291,0,0.0,True,False,0.0,-1.427116
291,1.0,0.0,0,-0.916291,0,0.0,1.0,0.0,1,-0.510826,0,1.0,True,False,0.0,-1.427116
81,1.0,0.0,0,-0.916291,0,0.0,1.0,0.0,1,-0.510826,1,1.0,True,False,0.0,-1.427116
342,1.0,0.0,1,-0.510826,1,1.0,1.0,0.0,0,-0.916291,0,0.0,True,False,0.0,-1.427116
346,1.0,0.0,0,-0.916291,1,1.0,1.0,0.0,0,-0.916291,0,0.0,True,False,0.0,-1.832581
37,1.0,0.0,0,-0.916291,0,0.0,1.0,0.0,0,-0.916291,0,0.0,True,False,0.0,-1.832581
52,1.0,0.0,0,-0.916291,0,0.0,1.0,0.0,0,-0.916291,1,1.0,True,False,0.0,-1.832581
79,1.0,0.0,1,-0.510826,1,1.0,1.0,0.0,1,-0.510826,1,1.0,True,True,-100000000.0,-100000000.0
199,1.0,0.0,1,-0.510826,1,1.0,1.0,0.0,1,-0.510826,0,1.0,True,True,-100000000.0,-100000000.0
150,1.0,0.0,1,-0.510826,0,1.0,1.0,0.0,1,-0.510826,1,1.0,True,True,-100000000.0,-100000000.0


In [15]:
def active(node, row):
    return row[f"apre_{node}"] == 0 and row[f"wpre_{node}"] == 0 and row[f"obs_{node}"] != row[f"pint_{node}"]

def minimal_sets(set_list):
    inclusion_minimal = []
    for s1 in set_list:
        is_minimal = True
        for s2 in set_list:
            if s1 != s2 and s2.issubset(s1):
                is_minimal = False
        if is_minimal:
            inclusion_minimal.append(s1)
    return inclusion_minimal


def minimal_cause_sets(table):
    
    subset_minimal_causes = []
    subset_logprobs = []
    changers = table[table['sum_lp']  > -1e7].copy()
    
    for _, row in changers.iterrows():
        active_set = {node for node in causal_candidates if active(node, row)}
        if active_set not in subset_minimal_causes:
            subset_minimal_causes.append(active_set)
            subset_logprobs.append(row['sum_lp'])
        subset_minimal_causes = minimal_sets(subset_minimal_causes)
        
    return subset_minimal_causes

minimal_conditioned = minimal_cause_sets(conditioned_table)

display(minimal_conditioned)


antecedent_candidates = causal_candidates

print(
all(any(candidate in s for candidate in antecedent_candidates) for s in minimal_conditioned)
)

# any( in set for candidate in ) for s in list_of_sets

# # Check if all sets in list_of_sets have at least one element from list2
# all_have_intersection = all(has_intersection_with_list2(s, list2) for s in list_of_sets)



[{'lightning'}, {'match_dropped'}]

['match_dropped', 'lightning']


In [5]:
# this reduces the actual causality check to checking a property of the resulting sums of log probabilities
# for the antecedent preemption and the consequent differs nodes

def ac_check(trace, mwc, antecedents, witnesses, consequents):

     table = get_table(trace, mwc, antecedents, witnesses, consequents)
     
     if (list(table['sum_log_prob'])[0]<= -1e8):
          print("No resulting difference to the consequent in the sample.")
          return
     
     winners = table[table['sum_log_prob'] == table['sum_log_prob'].max()]
     

     ac_flags = []
     for index, row in winners.iterrows():
          active_antecedents = []
          for antecedent in antecedents:
               if row[f"apr_{antecedent}"] == 0:
                    active_antecedents.append(antecedent)

          ac_flags.append(set(active_antecedents) == set(antecedents))

     if not any(ac_flags):
          print("The antecedent set is not minimal.")
     else:
          print("The antecedent set is an actual cause.")

     return any(ac_flags)

In [7]:
# processing of traces
# to identify potential explanations

# Is always a part of the antecedent
# a part of an actual cause of the consequent?
def sufficient_causality_checkA(output_dict, antecedents = None, witnesses = None, consequents = None):

    if antecedents is None:
        antecedents = output_dict['antecedents']
    if consequents is None:
        consequents = list(output_dict['consequents_observed'].keys())
    
    endogenous_nodes = output_dict['endogenous_nodes']
    causal_candidates = output_dict['causal_candidates']

    if witnesses is None:
        witnesses = [node for node in endogenous_nodes if (
                                          node not in antecedents.keys() and 
                                          node not in consequents)]

    table = get_table(output_dict['tr_sufficiency_A'],
                      output_dict['mwc_sufficiency_A'],
                      antecedents, witnesses, 
                      consequents)

    # a bit hacky, but adding antecedents to conditioning
    # within the first batch of handlers
    # led to tensor broadcasting issues
    for antecedent_str in antecedents.keys():
        table = table[table[f"{antecedent_str}_obs"] == antecedents[antecedent_str]]
    
    table = table[table['sum_log_prob'] > -1e8]
    
    # we need to check set inclusion minimality of cause sets
    # as there might be inclusion minimal sets that are not log-prob-sum minimal
    # just because they have a higher cardinality
    candidate_sets = []
    for i, row in table.iterrows():
        candidate_set = set()
        for node in causal_candidates:
            if row[f"{node}_int"] != row[f"{node}_obs"]:
                candidate_set.add(node)
        candidate_sets.append(candidate_set)
        
    actual_cause_sets = minimal_sets(candidate_sets)
        
    frozensets = [frozenset(s) for s in actual_cause_sets]
    unique_actual_cause_sets = [set(f) for f in set(frozensets)]
        
    sufficiency_flag = any(key in ac_set for key in antecedents.keys() for ac_set in actual_cause_sets)

    return  unique_actual_cause_sets, sufficiency_flag

In [8]:
# does fixing the antecedent always lead to a change in the consequent?

def sufficient_causality_checkB(output_dict, mwc = None, trace = None, antecedents = None, consequents = None):
    
    if antecedents is None:
        antecedents = output_dict['antecedents']
    if consequents is None:
        consequents = list(output_dict['consequents_observed'].keys())
    
    if trace is None:    
        trace = output_dict["tr_sufficiency_B"]    
    
    if mwc is None:
        mwc = output_dict["mwc_sufficiency_B"]
    
    outcome_df = pd.DataFrame()
    with mwc:
        for consequent in consequents:
            value = trace.trace.nodes[consequent]["value"]
            _indices = [
                    i for i in list(antecedents.keys()) if i in indices_of(value, event_dim=0)
                ]
            _int_con = gather(
            value, IndexSet(**{i: {1} for i in _indices}), event_dim=0,)
            outcome_df[consequent] = _int_con.squeeze().tolist()
        
    return ((outcome_df) == True).all().all()

In [9]:
# is the antecedent set a minimal set that 
# satisfies these two conditions?

def minimal_sufficiency_check(output_dict):
    antecedent_candidates = powerset(output_dict['antecedents'])
    
    a_checks = []
    b_checks = []
    for i, antecedent_candidate in enumerate(antecedent_candidates):
        a_checks.append(sufficient_causality_checkA(
            output_dict,antecedents = antecedent_candidate,            
                witnesses = [node for node in output_dict['endogenous_nodes'] if (
                            node not in antecedent_candidate.keys() and 
                            node not in output_dict['consequents_observed'].keys())])[1]
        )
    
        b_checks.append( sufficient_causality_checkB(output_dict, 
                                                     mwc = output_dict["mwc_candidate"][i],
                                                     trace = output_dict["tr_candidate"][i],
                                            antecedents = antecedent_candidate)
            )
        
    minimality = not any(a and b for a,b in zip(a_checks, b_checks))       
    
    return {"minimality": minimality, "a_checks": a_checks, "b_checks": b_checks, "antecedent_candidates": antecedent_candidates}


In [10]:
# is the explanation possible and nontrivial?

def possibility_and_nontriviality_check(output_dict):
    
    trace = output_dict["tr_priors"]
    antecedents = output_dict['antecedents']
    consequents_observed = output_dict['consequents_observed']
    
    reqs = {**antecedents, **consequents_observed}

    reqs_outcome = pd.DataFrame()

    for req in reqs:
        reqs_outcome[req] = trace.trace.nodes[req]["value"]
           
    possibility = (reqs_outcome == 1.0).all(axis=1).any()
    nontriviality = (reqs_outcome.iloc[:, :2] == 0.0).any(axis=1).any()
    
    return possibility, nontriviality

In [11]:
def explanation_check(output_object):
    
    sufficiencyA = sufficient_causality_checkA(output_object)[1]
    sufficiencyB = sufficient_causality_checkB(output_object)
    minimal_sufficiency = minimal_sufficiency_check(output_object)['minimality']
    possibility, nontriviality = possibility_and_nontriviality_check(output_object)
    
    return all([sufficiencyA, sufficiencyB, minimal_sufficiency, possibility, nontriviality])


# Examples

## Forest fire example

Example 7.1.2. from Halpern's *Actual Causality*. First, all contexts available, no settings excluded by what the agent knows about the world. In the conjunctive model, the joint nodes are an explanation of forest fire, none of the individual ones is. In the disjunctive model, the reverse is true.


In [12]:
def ff_conjunctive():
    u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped",
                                        u_match_dropped, event_dim=0)
    lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
    forest_fire = pyro.deterministic("forest_fire", (match_dropped.bool() & lightning.bool()),
                                      event_dim=0)

def ff_disjunctive():
        u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
        u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

        match_dropped = pyro.deterministic("match_dropped",
                                        u_match_dropped, event_dim=0)
        lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
        forest_fire = pyro.deterministic("forest_fire", (match_dropped.bool() | lightning.bool()).bool(), event_dim=0)

        return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}


In [13]:
exp_ff_con_handler =  Explanation_Evaluation(
    model = ff_conjunctive,
    antecedents={"match_dropped": 1.0, "lightning": 1.0},
    consequents_observed={"forest_fire": torch.tensor(True)},
    endogenous_nodes=["match_dropped", "lightning", "forest_fire"],
)
    
with exp_ff_con_handler as con_ff:
    ff_conjunctive()
    
explanation_check(con_ff)

True

In [14]:
exp_ff_con_separate_handler =  Explanation_Evaluation(
    model = ff_conjunctive,
    antecedents={"match_dropped": 1.0},
    consequents_observed={"forest_fire": torch.tensor(True)},
    endogenous_nodes=["match_dropped", "lightning", "forest_fire"],
)
     
with exp_ff_con_separate_handler as con_ff_separate:
    ff_conjunctive()
    
explanation_check(con_ff_separate)

False

In [15]:
exp_ff_dis_handler =  Explanation_Evaluation(
    model = ff_disjunctive,
    antecedents={"match_dropped": 1.0, "lightning": 1.0},
    consequents_observed={"forest_fire": torch.tensor(True)},
    endogenous_nodes=["match_dropped", "lightning", "forest_fire"],
) 
    
with exp_ff_dis_handler as dis_ff:
    ff_disjunctive()

explanation_check(dis_ff)

False

In [16]:
exp_ff_dis_separate_handler =  Explanation_Evaluation(
    model = ff_disjunctive,
    antecedents={"match_dropped": 1.0},
    consequents_observed={"forest_fire": torch.tensor(True)},
    endogenous_nodes=["match_dropped", "lightning", "forest_fire"],
) 

with exp_ff_dis_separate_handler as dis_ff_separate:
    ff_disjunctive()

explanation_check(dis_ff_separate)

True

### Extended forest fire

In April, given the electrical storm in May, the forest would have caught fire in May (and not in June). However, given the storm, if there had been an electrical storm only in May, the forest
would not have caught fire at all; if there had been an electrical storm only in June, it would have caught fire in June. The model has five endogenous variables: `ar` for *April rains*, 
`esm` for *electric storms in May*, `esj` for *electric storms in June*, `ffm` for *forest fire in May*, `ffj` for *forest fire in June* and `ff` for *forest fire either in May or in June (or both)*. 

In [17]:
def ff_extended():
    u_ar = pyro.sample("u_ar", dist.Bernoulli(0.5))
    u_esm = pyro.sample("u_esm", dist.Bernoulli(0.5))
    u_esj = pyro.sample("u_esj", dist.Bernoulli(0.5))

    ar = pyro.deterministic("ar", u_ar, event_dim=0)
    esm = pyro.deterministic("esm", u_esm, event_dim=0)
    esj = pyro.deterministic("esj", u_esj, event_dim=0)

    ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
    ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()

    ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()

    return {"u_ar": u_ar, "u_esm": u_esm, "u_esj": u_esj, 
            "ar": ar, "esm": esm, "esj": esj, "ffm": ffm, "ffj": ffj, "ff": ff}


In [18]:
# with no excluded settings
# one explanation for ff is esj

exp_ff_extended_handler =  Explanation_Evaluation(
    model = ff_extended,
    antecedents={"esj": 1.},
    consequents_observed={"ff": 1.},
    endogenous_nodes=["ar", "esm", "esj"],
) 

with exp_ff_extended_handler as ext_ff1:
    ff_extended()

explanation_check(ext_ff1)

  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()


True

In [19]:
# another is esm=1 and ar=0

exp_ff_extended_handler2 =  Explanation_Evaluation(
    model = ff_extended,
    antecedents= {"esm": 1., "ar": 0.},
    consequents_observed={"ff": 1.},
    endogenous_nodes=["ar", "esm", "esj"],
) 

with exp_ff_extended_handler2 as ext_ff2:
    ff_extended()

explanation_check(ext_ff2)

  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()


True

In [20]:
# explore all 27 possible explanations
# to confirm that these are the only 
# possible explanations

explanatory_status = []
candidates = []

nodes = ["ar", "esm", "esj"]
subsets = [[]]
for node in nodes:
        subsets.extend([subset + [node] for subset in subsets])


for subset in subsets:
        for _ in range(50):
                random_setting = [random.choice([0., 1.]) 
                          for _ in range(len(subset))]
                candidate = {var: val for var, val in zip(subset, random_setting)}

                if candidate in candidates:
                        continue

                candidates.append(candidate)

                _ext_handler =  Explanation_Evaluation(
                                model = ff_extended,
                                antecedents= candidate,
                                consequents_observed={"ff": 1.},
                                endogenous_nodes=["ar", "esm", "esj"],
                                ) 
                
                with _ext_handler as _ext_obj:
                        ff_extended()

                explanatory_status.append(explanation_check(_ext_obj))

explanation_search = pd.DataFrame({"candidates": candidates,
                                   "explanatory_status": explanatory_status})

print(explanation_search)



                

  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logi

  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffm = pyro.deterministic(

                             candidates  explanatory_status
0                                    {}               False
1                           {'ar': 0.0}               False
2                           {'ar': 1.0}               False
3                          {'esm': 1.0}               False
4                          {'esm': 0.0}               False
5               {'ar': 1.0, 'esm': 1.0}               False
6               {'ar': 0.0, 'esm': 1.0}                True
7               {'ar': 1.0, 'esm': 0.0}               False
8               {'ar': 0.0, 'esm': 0.0}               False
9                          {'esj': 1.0}                True
10                         {'esj': 0.0}               False
11              {'ar': 1.0, 'esj': 1.0}               False
12              {'ar': 0.0, 'esj': 0.0}               False
13              {'ar': 0.0, 'esj': 1.0}               False
14              {'ar': 1.0, 'esj': 0.0}               False
15             {'esm': 0.0, 'esj': 1.0} 

  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffm = pyro.deterministic("ffm", torch.logical_and(esm, ~ ar.bool()), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ffj = pyro.deterministic("ffj", torch.logical_and(esj, (ar.bool() | ~ esm.bool())), event_dim=0).float()
  ff = pyro.deterministic("ff", torch.logical_or(ffm, ffj), event_dim=0).float()
