In [1]:
import contextlib
import random
from itertools import chain, combinations
from typing import Dict, List

import pandas as pd
import pyro
import pyro.distributions as dist
import pyro.infer
import torch
from chirho.observational.handlers import condition
from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.counterfactual.handlers.explanation import ExplainCauses
                                            
from chirho.indexed.ops import (IndexSet, gather, indices_of) 
from chirho.interventional.handlers import do

In [2]:
def tensorize_dictionary(dictionary: Dict[str, float]) -> Dict[str, torch.Tensor]:
    return {k: torch.as_tensor(v) for k, v in dictionary.items()}

def boolean_constraints_from_list(list: List[str]) -> Dict[str, pyro.distributions.constraints.Constraint]:
    return {k: pyro.distributions.constraints.boolean for k in list}


@contextlib.contextmanager
def Explanation_Evaluation( 
        consequents_observed: Dict[str, torch.Tensor],
        causal_candidates: List[str],
        condition_on_consequent = True,
        runs_n: int = 100,):

        consequents = list(consequents_observed.keys())
        consequents_observed = tensorize_dictionary(consequents_observed)
        # this needs to be replaced if nodes are not boolean
        causal_candidate_constraints = boolean_constraints_from_list(causal_candidates)

        # needed in order to check if always a part of the antecedent set is a part of an actual cause
        with MultiWorldCounterfactual() as mwc:
            with ExplainCauses(antecedents = causal_candidate_constraints, 
                      witnesses = causal_candidates, consequents = consequents,
                      antecedent_bias = .1,):
                if condition_on_consequent:
                    with condition(data = consequents_observed):
                        with pyro.plate("sample", runs_n):
                            with pyro.poutine.trace() as tr:
                                yield {"mwc": mwc, "tr" : tr}
                else:
                    with pyro.plate("sample", runs_n):
                        with pyro.poutine.trace() as tr:
                            yield {"mwc": mwc, "tr" : tr}

## Implementing the original definition

The trace obtained using the above handler already contains information that can be used to implement the original definition. Such a query can be answered with only two core concepts: subset inclusion and log prob sum comparison. First, let's implement this definition. Later on, we'll move beyond some of the idiosyncracies involved here.

In [3]:
# getting a table from the trace
# post sampling conditioning on observations or interventions

def gather_observed(value, causal_candidates):
            
    _indices = [i for i in causal_candidates if i in indices_of(value, event_dim=0)]
    _int_can = gather(value, IndexSet(**{i: {0} for i in _indices}), event_dim=0,)
    return _int_can


def gather_intervened(value, causal_candidates):
        
    _indices = [i for i in causal_candidates if i in indices_of(value, event_dim=0)]
    _int_can = gather(value, IndexSet(**{i: {1} for i in _indices}), event_dim=0,)
    return _int_can


def get_explanation_table(trace, mwc, causal_candidates, consequents, prior_contributions = None):

    trace.trace.compute_log_prob()
    
    table_dict = {}
    nodes = trace.trace.nodes

    for candidate in causal_candidates:
        with mwc:
            table_dict[f"obs_{candidate}"] = gather_observed(nodes[candidate]['value'],
                                                             causal_candidates).squeeze().tolist()
    
            
        table_dict[f"pint_{candidate}"] = nodes[f'__antecedent__proposal_{candidate}']['value'].squeeze().tolist()
        table_dict[f"apre_{candidate}"] = nodes[f'__antecedent_{candidate}']['value'].squeeze().tolist()
        table_dict[f"lp_apre_{candidate}"] = nodes[f"__antecedent_{candidate}"]['log_prob']

        table_dict[f"wpre_{candidate}"] = nodes[f'__witness_{candidate}']['value'].squeeze().tolist()

        # context used twice to make the table more legible (meaningful column ordering)
        with mwc:
            table_dict[f"int_{candidate}"] = gather_intervened(nodes[candidate]['value'],
                                                             causal_candidates).squeeze().tolist()
            

    for consequent in consequents:
        with mwc:
            table_dict[f"obs_{consequent}"] = gather_observed(nodes[consequent]["value"], causal_candidates).squeeze().tolist()
            table_dict[f"int_{consequent}"] = gather_intervened(nodes[consequent]["value"], causal_candidates).squeeze().tolist()    
            table_dict[f"lp_{consequent}"] = gather_intervened(nodes[f"__consequent_{consequent}"]["log_prob"], causal_candidates).squeeze().tolist()
        
    if prior_contributions is not None:
        for node in prior_contributions:
            table_dict[f"lp_{node}"] = gather_observed(nodes[node]["log_prob"], causal_candidates).squeeze().tolist()
            
    table_pd = pd.DataFrame(table_dict).drop_duplicates()

    # small cleanup:
    # remove rows where proposed interventions are the same as the observed values
    for candidate in causal_candidates:
        mask = table_pd[f'obs_{candidate}'] != table_pd[f'pint_{candidate}']
        table_pd = table_pd[mask]
      
    
    summands = [col for col in table_pd.columns if col.startswith('lp')]
    table_pd["sum_lp"] =  table_pd[summands].sum(axis = 1)
    table_pd.sort_values(by = "sum_lp", ascending = False, inplace = True)

    # some sanity checks
    for candidate in causal_candidates:
        
        # witness preempted nodes have the same observed and intervened values
        assert all(table_pd.loc[table_pd[f'wpre_{candidate}'] == 1, 
        f'obs_{candidate}'] == table_pd.loc[table_pd[f'wpre_{candidate}'] == 1,
                                            f'int_{candidate}'])

        # antecedent preemptions leave obs and int values unchanged
        # this might fail generally if a node is downstream from some 
        # other intervention, but let's keep this sanity check for the simple models 
        # for now        
        assert all(table_pd.loc[table_pd[f'apre_{candidate}'] == 1,
                                f'obs_{candidate}'] == 
                   table_pd.loc[table_pd[f'apre_{candidate}'] == 1,
                                f'int_{candidate}'])
        

    return table_pd


# brute force conditioning on the observed causal candidates
# needed to implement the original def, not needed later on

# post-sampling conditioning on observations
def get_conditioned_table(table, dict_of_nodes):
    conditioned_table = table.copy()
    for candidate in dict_of_nodes.keys():
        obs_column = f'obs_{candidate}'
        mask = (conditioned_table[obs_column] == dict_of_nodes[candidate])
        conditioned_table = conditioned_table[mask]
        conditioned_table.reset_index(drop=True, inplace=True)

    return conditioned_table

# post-sampling conditioning on selected variables being intervened
def get_intervened_table(table, dict_of_nodes):
    intervened_table = table.copy()
    for candidate in dict_of_nodes.keys():
        condition_string = f"`apre_{candidate}` == 0 and `wpre_{candidate}` == 0 and int_{candidate} == {dict_of_nodes[candidate]}"        
        intervened_table = intervened_table.query(condition_string)
        intervened_table.reset_index(drop=True, inplace=True)

    return intervened_table
    

# a few other utility functions

# from a dictionary of interventions generate all non-empty subdictionaries
# needed as a search space for subset-inclusion
def powerset(dct):
    keys = list(dct.keys())
    key_tuples = list(chain.from_iterable(combinations(keys, r) for r in range(len(keys) + 1)))[:-1]
    subdicts = [{k: dct[k] for k in tpl} for tpl in key_tuples]
    return subdicts

#  in a given row in a sample table, which nodes are active in an intervention considered in that row?
def active(node, row):
    return row[f"apre_{node}"] == 0 and row[f"wpre_{node}"] == 0 and row[f"obs_{node}"] != row[f"pint_{node}"]

def minimal_sets(set_list):
    inclusion_minimal = []
    for s1 in set_list:
        is_minimal = True
        for s2 in set_list:
            if s1 != s2 and s2.issubset(s1):
                is_minimal = False
        if is_minimal:
            inclusion_minimal.append(s1)
    return inclusion_minimal


# go through those rows where the intervened consequents differ from the observed ones
# keep track of subset-minimal interventions and the corresponding row log_probs.
# needed as there might be subset-minimal interventions that aren't in a log_prob-maximal row
def minimal_cause_sets(table, causal_candidates, get_values = False, return_logprobs = False):

    if isinstance(causal_candidates, dict):
        causal_candidates = list(causal_candidates.keys())
    
    subset_minimal_causes = []
    subset_minimal_causes_dicts = []
    subset_logprobs = []
    frozen_sets = set()
    changers = table[table['sum_lp']  > -1e8].copy()
    
    for _, row in changers.iterrows():
        active_set = {node for node in  causal_candidates if active(node, row)}
        active_set_frozen = frozenset(active_set)
        active_dict = {node: row[f"obs_{node}"] for node in active_set}
        if active_set_frozen not in frozen_sets:
            frozen_sets.add(active_set_frozen)
            subset_minimal_causes.append(active_set)
            subset_minimal_causes_dicts.append(active_dict)
            subset_logprobs.append(row['sum_lp'])


    # output, depending on what is needed, as specified by the arguments
    if not get_values and not return_logprobs:
        return subset_minimal_causes
    elif get_values and not return_logprobs:
        return subset_minimal_causes_dicts
    elif not get_values and return_logprobs:
        return subset_minimal_causes, subset_logprobs
    elif get_values and return_logprobs:
        return subset_minimal_causes_dicts, subset_logprobs

                

In [4]:
class ExplanationHalpern():

    def __init__(self, table, consequent_dict, causal_candidates_dict, print_report = True):
        self.table = table
        self.consequent_dict = consequent_dict
        self.causal_candidates_dict = causal_candidates_dict

        self.ex1A = self.ex1A_check(self.table, consequent_dict= self.consequent_dict, causal_candidates_dict= self.causal_candidates_dict)
        self.ex1B = self.ex1B_check(self.table, consequent_dict=self.consequent_dict, causal_candidates_dict= self.causal_candidates_dict)
        self.ex2 = self.ex2_check(self.table, consequent_dict= self.consequent_dict, causal_candidates_dict= self.causal_candidates_dict)
        self.ex3_4 = self.ex3_4_check(self.table, consequent_dict=self.consequent_dict, causal_candidates_dict=self.causal_candidates_dict)
        self.explanation = self.explanation_check()
        if print_report:
                print(f"explanation check: {self.explanation}, ", 
                f"causal_overlap: {self.ex1B}, ",
                f"minimality: {self.ex2}, ",
                f"possibility: {self.ex3_4[0]}, ", 
                f"non-triviality: {self.ex3_4[1]}")
        
    # Condition 1 requires that (A) in all contexts in which the causal candidates have the postulated values 
    # and the consequent is changed, there is at least one member of the causal candidates 
    # that is a member of a minimal actual cause of that change to the consequent.
    def ex1A_check(self, table, consequent_dict, causal_candidates_dict):

        # suppose not only consequent but also the causal candidates are observed    
        merged_dict = {**consequent_dict, **causal_candidates_dict}
        
        table_conditioned = get_conditioned_table(table, merged_dict)

        # check if causal candidates coincide with a minimal cause set
        minimal_conditioned = minimal_cause_sets(table_conditioned, causal_candidates_dict)
        parthood_flag = any(any(candidate in s for candidate in causal_candidates_dict.keys()) for s in minimal_conditioned)   
        return parthood_flag, minimal_conditioned
    
    # Moreover, it requires that (B) for any context under consideration, intervening on all causal candidates 
    # to have the postulated value leads to the consequent.
    def ex1B_check(self, table, causal_candidates_dict, consequent_dict):
        table_intervened = get_intervened_table(table, causal_candidates_dict)
        return all((table_intervened[f"int_{key}"] == consequent_dict[key]).all() for key in consequent_dict.keys()) 

    # The second condition of Halpern's definition is that the causal candidate set 
    # be a subset-minimal one satisfying both `ex1A` and `ex1B`.
    def ex2_check(self, table,  consequent_dict, causal_candidates_dict,):
        sub_candidates = powerset(causal_candidates_dict)
        minimality_flag = True
        for sub_candidate in sub_candidates:
            ex1A_flag = self.ex1A_check(table, consequent_dict= consequent_dict, causal_candidates_dict= sub_candidate)[0]
            ex1B_flag = self.ex1B_check(table, consequent_dict = consequent_dict, causal_candidates_dict= sub_candidate)
            
            if ex1A_flag and ex1B_flag:
                minimality_flag = False
                break
        return minimality_flag
    
    # The remaining checks are whether the explanation is possible (given the consequent) and non-trivial.
    def ex3_4_check(self, table, consequent_dict, causal_candidates_dict):

        merged_dict = {**consequent_dict, **causal_candidates_dict}
        table_conditioned = get_conditioned_table(table, merged_dict)
        ex3 = table_conditioned.shape[0] > 0
        ex4 = any(any(table[f"obs_{key}"] != causal_candidates_dict[key]) for key in causal_candidates_dict.keys())

        return ex3, ex4
    
    def explanation_check(self):
        return self.ex1A[0] and self.ex1B and self.ex2 and self.ex3_4[0] 

# Examples

## Forest fire example

Example 7.1.2. from Halpern's *Actual Causality*. In the conjunctive model if both a lightning occurs and a match is dropped, a forest fire results, but both factors are required. In the disjunctive model, each of these factors individually is sufficient for the forest fire.

Suppose all contexts are available and no setting is excluded by what the agent knows about the world. 
- In the conjunctive model, the joint nodes are an explanation of forest fire, none of the individual ones is. This is in contrast with actual causality claims, as each of the nodes is an actual cause, but the conjunction is not.
- In the disjunctive model, the reverse is true. Each node is an explanation of forest fire, but the conjunction is not. This is in contrast with actual causality claims, as none of the individual nodes is an actual cause, but the conjunction is.

In [5]:
@pyro.infer.config_enumerate
def ff_conjunctive():
    u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped",
                                        u_match_dropped, event_dim=0)
    lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
    forest_fire = pyro.deterministic("forest_fire", (match_dropped.bool() & lightning.bool()),
                                      event_dim=0)

@pyro.infer.config_enumerate
def ff_disjunctive():
        u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
        u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

        match_dropped = pyro.deterministic("match_dropped",
                                        u_match_dropped, event_dim=0)
        lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
        forest_fire = pyro.deterministic("forest_fire", (match_dropped.bool() | lightning.bool()).bool(), event_dim=0)

        return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}


In [6]:
consequents_observed={"forest_fire": torch.tensor(True)}
causal_candidates=["match_dropped", "lightning"]
causal_candidates_dict = {"match_dropped": 1., "lightning": 1.}
antecedent_prefix = "__antecedent"

exp_ff_con_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates,
    condition_on_consequent = False,
    runs_n= 1000,
)

exp_ff_dis_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates,
    condition_on_consequent = False,
    runs_n= 1000,
)

with exp_ff_con_handler as con_ff:
        ff_conjunctive()

with exp_ff_dis_handler as dis_ff:
        ff_disjunctive()

table_con_ff = get_explanation_table(con_ff["tr"], con_ff["mwc"], causal_candidates, consequents=["forest_fire"])
table_dis_ff = get_explanation_table(dis_ff["tr"], dis_ff["mwc"], causal_candidates, consequents=["forest_fire"])

In [7]:
# conjunction in the conjunctive model
con_ff_explanation = ExplanationHalpern(table_con_ff, {"forest_fire": True}, causal_candidates_dict)
# one of the nodes in the conjunctive model
con_ff_explanation_match = ExplanationHalpern(table_con_ff, {"forest_fire": True}, {"match_dropped": 1.})

# conjunction in the disjunctive model
dis_ff_explanation = ExplanationHalpern(table_dis_ff, {"forest_fire": True}, causal_candidates_dict)
# one of the nodes in the disjunctive model
dis_ff_explanation_match = ExplanationHalpern(table_dis_ff, {"forest_fire": True}, {"match_dropped": 1.})

explanation check: True,  causal_overlap: True,  minimality: True,  possibility: True,  non-triviality: True
explanation check: False,  causal_overlap: False,  minimality: True,  possibility: True,  non-triviality: True
explanation check: False,  causal_overlap: True,  minimality: False,  possibility: True,  non-triviality: True
explanation check: True,  causal_overlap: True,  minimality: True,  possibility: True,  non-triviality: True


## Extended forest fire

This is based on example 7.1.4. In April, given the electrical storm in May, the forest would have caught fire in May (and not in June). However, given the storm, if there had been an electrical storm only in May, the forest
would not have caught fire at all; if there had been an electrical storm only in June, it would have caught fire in June. The model has five endogenous variables: `ar` for *April rains*, 
`esm` for *electric storms in May*, `esj` for *electric storms in June*, `ffm` for *forest fire in May*, `ffj` for *forest fire in June* and `ff` for *forest fire either in May or in June (or both)*. 

In [8]:
def ff_extended():
    u_ar = pyro.sample("u_ar", dist.Bernoulli(0.5))
    u_esm = pyro.sample("u_esm", dist.Bernoulli(0.5))
    u_esj = pyro.sample("u_esj", dist.Bernoulli(0.5))

    ar = pyro.deterministic("ar", u_ar, event_dim=0)
    esm = pyro.deterministic("esm", u_esm, event_dim=0)
    esj = pyro.deterministic("esj", u_esj, event_dim=0)

    ffm = pyro.deterministic("ffmt", esm  * (1 - ar), event_dim=0).float()
    ffj = pyro.deterministic("ffj", (esj * torch.max(ar, (1 - esm))), event_dim=0).float()
    ff = pyro.deterministic("ff", torch.max(ffm, ffj), event_dim=0).float()
    
    return {"u_ar": u_ar, "u_esm": u_esm, "u_esj": u_esj, 
            "ar": ar, "esm": esm, "esj": esj, "ffm": ffm, "ffj": ffj, "ff": ff}


In [9]:
consequents_observed={"ff": 1.}
causal_candidates_dict = {"esj": 1., "esm": 1.,  "ar": 0.}
causal_candidates= list(causal_candidates_dict.keys())

exp_ff_ext_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates,
    condition_on_consequent = False,
    runs_n= 10000,
)

with exp_ff_ext_handler as ext_ff:
    ff_extended()

table_ext_ff = get_explanation_table(ext_ff["tr"], ext_ff["mwc"], causal_candidates, consequents=["ff"])


ext_ff_1_explanation = ExplanationHalpern(table_ext_ff, {"ff": 1.}, {"esm": 1.,  "ar": 0.})
ext_ff_1_explanation = ExplanationHalpern(table_ext_ff, {"ff": 1.}, {"esj": 1.})
ext_ff_1_explanation = ExplanationHalpern(table_ext_ff, {"ff": 1.}, {"esj": 1, "esm": 1.,  "ar": 0.})
ext_ff_1_explanation = ExplanationHalpern(table_ext_ff, {"ff": 1.},{'ar': 1.0, 'esm': 0.0}) 

explanation check: True,  causal_overlap: True,  minimality: True,  possibility: True,  non-triviality: True
explanation check: True,  causal_overlap: True,  minimality: True,  possibility: True,  non-triviality: True
explanation check: False,  causal_overlap: True,  minimality: False,  possibility: True,  non-triviality: True
explanation check: False,  causal_overlap: False,  minimality: True,  possibility: True,  non-triviality: True


Ideally, however, we would like to explore the space of possible explanations - in this particular case, confirming that these are the only two explanations. If we approach it from the perspective of the query of the type "is a given setting of a certain collection of endogenous variables an explanation of a given outcome", it seems like the way to go is to manually propose candidates and run a separate query every time.

In [10]:
# explore all 27 possible explanations
# to confirm that these are the only 
# possible explanations

explanatory_status = []
candidates = []

nodes = ["ar", "esm", "esj"]
subsets = [[]]
for node in nodes:
        subsets.extend([subset + [node] for subset in subsets])


for subset in subsets:
        for _ in range(50):
                random_setting = [random.choice([0., 1.]) 
                          for _ in range(len(subset))]
                candidate = {var: val for var, val in zip(subset, random_setting)}

                if candidate in candidates:
                        continue

                candidates.append(candidate)
                
                explanatory_status.append(ExplanationHalpern(table_ext_ff, {"ff": 1.}, 
                                                             candidate, 
                                                             print_report = False).explanation)

explanation_search = pd.DataFrame({"candidates": candidates,
                                   "explanatory_status": explanatory_status})

print(explanation_search)
    

                             candidates  explanatory_status
0                                    {}               False
1                           {'ar': 1.0}               False
2                           {'ar': 0.0}               False
3                          {'esm': 1.0}               False
4                          {'esm': 0.0}               False
5               {'ar': 1.0, 'esm': 0.0}               False
6               {'ar': 0.0, 'esm': 0.0}               False
7               {'ar': 1.0, 'esm': 1.0}               False
8               {'ar': 0.0, 'esm': 1.0}                True
9                          {'esj': 0.0}               False
10                         {'esj': 1.0}                True
11              {'ar': 0.0, 'esj': 1.0}               False
12              {'ar': 1.0, 'esj': 1.0}               False
13              {'ar': 1.0, 'esj': 0.0}               False
14              {'ar': 0.0, 'esj': 0.0}               False
15             {'esm': 0.0, 'esj': 1.0} 

# Going beyond the original definition

While actual causes do not have to be explanations in Halpern's sense, one key intuition behind why one needs a notion of an explanation as defined by Halpern is that one is interested in narrowing down the search space to things that would **overlap** with actual causes if they were true and guiding our investigation in the search of actual causes. We propose an explanation, and then investigate the node values to see what the actual causes are. If this is your motivation, the approach we present allows you to focusing on possible actual causes directly and skip this extra step, by finding possible states that would **be** actual causes, if true. In fact, the handler that we have already explores the possible actual causes for us, and instead of running a wider search we can just check the potential actual causes for whether they are explanations, if this feature matters to us for some reason.

In [11]:
table_ext_ff_conditioned = get_conditioned_table(table_ext_ff, {"ff": 1.})
display(table_ext_ff_conditioned[table_ext_ff_conditioned['sum_lp'] > -1e7])

possible_actual_causes = minimal_cause_sets(table_ext_ff_conditioned, ["ar", "esm", "esj"], get_values=True)
print(possible_actual_causes)

for pac in possible_actual_causes:
    if ExplanationHalpern(table_ext_ff, {"ff": 1.}, pac, print_report = False).explanation:
        print("I would count as Halpern's explanation!", pac)

Unnamed: 0,obs_esj,pint_esj,apre_esj,lp_apre_esj,wpre_esj,int_esj,obs_esm,pint_esm,apre_esm,lp_apre_esm,...,obs_ar,pint_ar,apre_ar,lp_apre_ar,wpre_ar,int_ar,obs_ff,int_ff,lp_ff,sum_lp
0,1.0,0.0,0,-0.916291,0,0.0,0.0,1.0,1,-0.510826,...,0.0,1.0,1,-0.510826,1,0.0,1.0,0.0,0.0,-1.937942
1,1.0,0.0,0,-0.916291,0,0.0,0.0,1.0,1,-0.510826,...,1.0,0.0,1,-0.510826,1,1.0,1.0,0.0,0.0,-1.937942
2,1.0,0.0,0,-0.916291,0,0.0,0.0,1.0,1,-0.510826,...,0.0,1.0,1,-0.510826,0,0.0,1.0,0.0,0.0,-1.937942
3,0.0,1.0,1,-0.510826,1,0.0,1.0,0.0,1,-0.510826,...,0.0,1.0,0,-0.916291,0,1.0,1.0,0.0,0.0,-1.937942
4,0.0,1.0,1,-0.510826,1,0.0,1.0,0.0,0,-0.916291,...,0.0,1.0,1,-0.510826,1,0.0,1.0,0.0,0.0,-1.937942
5,0.0,1.0,1,-0.510826,1,0.0,1.0,0.0,0,-0.916291,...,0.0,1.0,1,-0.510826,0,0.0,1.0,0.0,0.0,-1.937942
6,0.0,1.0,1,-0.510826,0,0.0,1.0,0.0,1,-0.510826,...,0.0,1.0,0,-0.916291,0,1.0,1.0,0.0,0.0,-1.937942
7,0.0,1.0,1,-0.510826,0,0.0,1.0,0.0,0,-0.916291,...,0.0,1.0,1,-0.510826,0,0.0,1.0,0.0,0.0,-1.937942
8,1.0,0.0,0,-0.916291,0,0.0,1.0,0.0,1,-0.510826,...,1.0,0.0,1,-0.510826,1,1.0,1.0,0.0,0.0,-1.937942
9,1.0,0.0,0,-0.916291,0,0.0,0.0,1.0,1,-0.510826,...,1.0,0.0,1,-0.510826,0,1.0,1.0,0.0,0.0,-1.937942


[{'esj': 1.0}, {'ar': 0.0}, {'esm': 1.0}, {'ar': 0.0, 'esm': 1.0}, {'esj': 1.0, 'ar': 0.0}, {'esj': 1.0, 'esm': 1.0}, {'esj': 1.0, 'ar': 0.0, 'esm': 0.0}]
I would count as Halpern's explanation! {'esj': 1.0}
I would count as Halpern's explanation! {'ar': 0.0, 'esm': 1.0}


Notice also that the bias used in the effect handlers results in smaller sets being ranked higher, which also can be useful in your search. This feature may be used in interaction with our priors about which states of possible causes are more likely, we just need to add the corresponding log probabilities to the log prob sum. For a simple illustration, let us get back to the disjunctive model of forest fire, with the caveat that now we think a lightning is much less likely.

In [15]:
@pyro.infer.config_enumerate
def ff_disjunctive_uneven():
        u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.6))
        u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.4))

        match_dropped = pyro.deterministic("match_dropped",
                                        u_match_dropped, event_dim=0)
        lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
        forest_fire = pyro.deterministic("forest_fire", torch.max(match_dropped, lightning), event_dim=0)

        return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}

ff_disjunctive_uneven()

{'match_dropped': tensor(1.),
 'lightning': tensor(0.),
 'forest_fire': tensor(1.)}

In [16]:
consequents_observed={"forest_fire": 1.}
causal_candidates=["lightning","match_dropped"]

exp_ff_dis__uneven_handler =  Explanation_Evaluation(
    consequents_observed=consequents_observed,
    causal_candidates= causal_candidates,
    condition_on_consequent = False,
    runs_n= 4000,
)

with exp_ff_dis__uneven_handler as dis_ff_uneven:
        ff_disjunctive_uneven()

dis_ff_uneven["tr"].trace.compute_log_prob()


In [17]:

table_dis_ff_uneven = get_explanation_table(dis_ff_uneven["tr"], dis_ff_uneven["mwc"], causal_candidates, consequents=["forest_fire"],
                                             prior_contributions=["u_match_dropped", "u_lightning"])

table_dis_ff_uneven_conditioned = get_conditioned_table(table_dis_ff_uneven, {"forest_fire": 1.})

possible_actual_causes = minimal_cause_sets(table_dis_ff_uneven_conditioned, ["match_dropped", "lightning"], get_values=True, return_logprobs=True)


print(possible_actual_causes)

# the proper way to read this: "just match dropped and no lightning" are ranked higher than
# "match_dropped and lightning", which are ranked higher than "no match dropped and just the lightning"

([{'match_dropped': 1.0}, {'lightning': 1.0, 'match_dropped': 1.0}, {'lightning': 1.0}], [-2.4487674832344055, -3.259697675704956, -3.2596977949142456])
