In [29]:
import functools
import contextlib
import collections
from typing import Callable, Iterable, TypeVar, Mapping, List, Dict

import random

import pyro
import torch  # noqa: F401

from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.indexed.ops import IndexSet, cond, gather, indices_of, scatter

S = TypeVar("S")
T = TypeVar("T")

import pyro
import chirho
import pyro.distributions as dist
import pyro.infer
import torch
import pandas as pd

from chirho.counterfactual.handlers.counterfactual import (MultiWorldCounterfactual,
        Preemptions)
from chirho.counterfactual.handlers.explanation import (
    SearchForCause,
    consequent_differs,
    random_intervention,
    undo_split,
    uniform_proposal,
)
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.observational.handlers.condition import Factors, condition
from chirho.interventional.ops import Intervention, intervene
from chirho.interventional.handlers import do

In [30]:
@contextlib.contextmanager
def ExplainCauses(
    antecedents: Mapping[str, Intervention[T]]
    | Mapping[str, pyro.distributions.constraints.Constraint],
    witnesses: Mapping[str, Intervention[T]] | Iterable[str],
    consequents: Mapping[str, Callable[[T], float | torch.Tensor]]
    | Iterable[str],
    *,
    antecedent_bias: float = 0.0,
    witness_bias: float = 0.0,
    consequent_eps: float = -1e8,
    antecedent_prefix: str = "__antecedent_",
    witness_prefix: str = "__witness_",
    consequent_prefix: str = "__consequent_",
):
    """
    Effect handler for causal explanation.

    Runs the body of the ``with`` block under three handlers:
    ``SearchForCause`` over the antecedents, ``Preemptions`` over the
    witnesses and ``Factors`` over the consequents.

    :param antecedents: A mapping from antecedent names to interventions, or to
        constraints. Every value that is a ``Constraint`` is replaced by a
        ``random_intervention`` over that constraint; other values are used
        as-is, so a mapping may mix both kinds.
    :param witnesses: A mapping from witness names to interventions, or an
        iterable of witness names (a single string counts as one name). Names
        are mapped to ``undo_split`` over the antecedents.
    :param consequents: A mapping from consequent names to factor functions, or
        an iterable of consequent names (a single string counts as one name).
        Names are mapped to ``consequent_differs`` over the antecedents.
    :param antecedent_bias: ``bias`` passed to ``SearchForCause``.
    :param witness_bias: ``bias`` passed to ``Preemptions``.
    :param consequent_eps: ``eps`` passed to ``consequent_differs`` when
        consequents are given by name.
    :param antecedent_prefix: ``prefix`` passed to ``SearchForCause``.
    :param witness_prefix: ``prefix`` passed to ``Preemptions``.
    :param consequent_prefix: ``prefix`` passed to ``Factors``.
    :raises ValueError: if there is no antecedent or no consequent, or if the
        consequents overlap the antecedents or the witnesses.
    """
    # Validate before peeking at any value: inside a generator, the
    # StopIteration from ``next(iter({}))`` would be turned into a RuntimeError
    # (PEP 479) and this ValueError would never be raised.
    if len(antecedents) == 0:
        raise ValueError("must have at least one antecedent")

    antecedent_names = list(antecedents.keys())

    # Convert constraints per entry (not by looking only at the first value),
    # so explicit interventions are never wrapped in a random intervention.
    antecedents = {
        a: (
            random_intervention(s, name=f"{antecedent_prefix}_proposal_{a}")
            if isinstance(s, pyro.distributions.constraints.Constraint)
            else s
        )
        for a, s in antecedents.items()
    }

    # A bare string is one name, not an iterable of one-character names.
    if isinstance(witnesses, str):
        witnesses = [witnesses]
    if not isinstance(witnesses, collections.abc.Mapping):
        witnesses = {
            w: undo_split(antecedents=antecedent_names) for w in witnesses
        }

    if isinstance(consequents, str):
        consequents = [consequents]
    if not isinstance(consequents, collections.abc.Mapping):
        consequents = {
            c: consequent_differs(antecedents=antecedent_names, eps=consequent_eps)
            for c in consequents
        }

    if len(consequents) == 0:
        raise ValueError("must have at least one consequent")

    if set(consequents.keys()) & set(antecedents.keys()):
        raise ValueError(
            "consequents and possible antecedents must be disjoint"
        )

    if set(consequents.keys()) & set(witnesses.keys()):
        raise ValueError("consequents and possible witnesses must be disjoint")

    antecedent_handler = SearchForCause(
        actions=antecedents, bias=antecedent_bias, prefix=antecedent_prefix
    )
    witness_handler = Preemptions(
        actions=witnesses, bias=witness_bias, prefix=witness_prefix
    )
    consequent_handler = Factors(factors=consequents, prefix=consequent_prefix)

    with antecedent_handler, witness_handler, consequent_handler:
        yield

In [31]:
def _gather_world(value, antecedents, witnesses, world):
    """Keep a single world index of ``value`` along its antecedent/witness splits.

    Meant to be called inside the ``MultiWorldCounterfactual`` context whose
    splits index ``value`` (as ``get_table`` does).

    :param value: The tensor to select from.
    :param antecedents: Antecedent names: a mapping (its keys are used) or any
        iterable of names.
    :param witnesses: Witness names: any iterable of names; ``None`` or empty
        means no witnesses.
    :param world: Index kept along every split dimension present in ``value``.
    :returns: ``value`` gathered at ``world`` along each named split that
        actually indexes it. Names with no such split are ignored.
    """
    names = list(antecedents) + list(witnesses or ())
    # Computed once; it depends only on ``value`` and the active handlers.
    present = indices_of(value, event_dim=0)
    selection = IndexSet(**{name: {world} for name in names if name in present})
    return gather(value, selection, event_dim=0)


def gather_observed(value, antecedents, witnesses):
    """Return ``value`` at world index 0 (the observed world in this notebook)
    of every antecedent/witness split present in it.

    :param antecedents: Mapping (keys used) or iterable of antecedent names.
    :param witnesses: Iterable of witness names (``None`` means none).
    """
    return _gather_world(value, antecedents, witnesses, 0)


def gather_intervened(value, antecedents, witnesses):
    """Return ``value`` at world index 1 (the intervened world in this notebook)
    of every antecedent/witness split present in it.

    :param antecedents: Mapping (keys used) or iterable of antecedent names.
    :param witnesses: Iterable of witness names (``None`` means none).
    """
    return _gather_world(value, antecedents, witnesses, 1)


def get_table(
    trace,
    mwc,
    antecedents,
    witnesses,
    consequents,
    *,
    antecedent_prefix: str = "__antecedent_",
    witness_prefix: str = "__witness_",
    consequent_prefix: str = "__consequent_",
):
    """Tabulate observed/intervened values and log-probabilities from a trace.

    Columns produced (one row per sample):

    * for each antecedent ``a``: ``a_obs`` / ``a_int`` (value gathered at world
      0 / 1), ``apr_a`` (value of site ``{antecedent_prefix}a``) and
      ``apr_a_lp`` (that value's log-prob under the site's ``fn``);
    * for each witness ``w``: ``w_obs`` / ``w_int`` and ``wpr_w`` (value of
      site ``{witness_prefix}w``);
    * for each consequent ``c``: ``c_obs`` / ``c_int`` and ``c_lp`` (log-prob
      of the ``{consequent_prefix}c`` factor, gathered at world 1);
    * ``sum_log_prob``: the sum of all ``*_lp`` columns listed above.

    Duplicate rows are dropped, and rows are sorted by ``sum_log_prob``
    (highest first).

    :param trace: A pyro trace messenger; its ``.trace.nodes`` is read.
    :param mwc: The ``MultiWorldCounterfactual`` handler the trace was recorded
        under; re-entered so world indices can be resolved.
    :param antecedents: Mapping (keys used) or iterable of antecedent names.
    :param witnesses: Iterable of witness names; ``None``/empty for none.
    :param consequents: Iterable of consequent names.
    :param antecedent_prefix: Site prefix of the antecedent preemption sites
        (default matches ``ExplainCauses``).
    :param witness_prefix: Site prefix of the witness sites.
    :param consequent_prefix: Site prefix of the consequent factor sites.
    :returns: A ``pandas.DataFrame`` as described above.
    """
    nodes = trace.trace.nodes
    antecedents_list = list(antecedents)
    witnesses_list = list(witnesses) if witnesses else []

    values_table = {}
    # Log-prob columns are tracked explicitly: matching on a name suffix would
    # also pick up e.g. ``apr_help`` for an antecedent called "help".
    lp_columns = []

    def as_list(tensor):
        return tensor.squeeze().tolist()

    def add_obs_int(name):
        value = nodes[name]["value"]
        values_table[f"{name}_obs"] = as_list(
            gather_observed(value, antecedents_list, witnesses_list)
        )
        values_table[f"{name}_int"] = as_list(
            gather_intervened(value, antecedents_list, witnesses_list)
        )

    with mwc:
        for antecedent in antecedents_list:
            add_obs_int(antecedent)
            apr_node = nodes[f"{antecedent_prefix}{antecedent}"]
            apr_value = apr_node["value"]
            values_table[f"apr_{antecedent}"] = as_list(apr_value)
            # Converted like every other column (previously a raw tensor).
            values_table[f"apr_{antecedent}_lp"] = as_list(
                apr_node["fn"].log_prob(apr_value)
            )
            lp_columns.append(f"apr_{antecedent}_lp")

        for witness in witnesses_list:
            add_obs_int(witness)
            values_table[f"wpr_{witness}"] = as_list(
                nodes[f"{witness_prefix}{witness}"]["value"]
            )

        for consequent in consequents:
            add_obs_int(consequent)
            # TODO: evaluating the factor's log_prob at a dummy value of 1 feels
            # like a hack; kept as in the original analysis.
            con_lp = nodes[f"{consequent_prefix}{consequent}"]["fn"].log_prob(
                torch.tensor(1)
            )
            values_table[f"{consequent}_lp"] = as_list(
                gather_intervened(con_lp, antecedents_list, witnesses_list)
            )
            lp_columns.append(f"{consequent}_lp")

    summands = list(dict.fromkeys(lp_columns))
    values_df = pd.DataFrame(values_table).drop_duplicates()
    return values_df.assign(
        sum_log_prob=values_df[summands].sum(axis=1)
    ).sort_values(by="sum_log_prob", ascending=False)
    


In [32]:
def ac_check(trace, mwc, antecedents, witnesses, consequents, *, eps: float = -1e8):
    """Check whether ``antecedents`` together form an actual cause.

    The actual-causality check is reduced to a property of the
    ``sum_log_prob`` column built by ``get_table`` (antecedent preemption plus
    consequent log-probabilities). Among the rows with the highest sum, is
    there one in which every antecedent's ``apr_<name>`` value is 0 (treated
    here as "active")?

    :param trace, mwc, antecedents, witnesses, consequents: forwarded to
        ``get_table``.
    :param eps: if the best ``sum_log_prob`` is at or below this value, no
        sampled row shows a difference in the consequent. The default equals
        ``ExplainCauses``' default ``consequent_eps``.
    :returns: ``None`` (after printing a message) if no row shows a
        difference in the consequent (including an empty table); otherwise
        ``True`` if the antecedent set is an actual cause, ``False`` if it is
        not minimal. A message is printed in every case.
    """
    table = get_table(trace, mwc, antecedents, witnesses, consequents)

    # ``max`` of an empty column is NaN, so the emptiness test must come first.
    best_score = table["sum_log_prob"].max()
    if table.empty or best_score <= eps:
        print("No resulting difference to the consequent in the sample.")
        return None

    winners = table[table["sum_log_prob"] == best_score]

    required = set(antecedents)
    is_actual_cause = any(
        {a for a in antecedents if row[f"apr_{a}"] == 0} == required
        for _, row in winners.iterrows()
    )

    if is_actual_cause:
        print("The antecedent set is an actual cause.")
    else:
        print("The antecedent set is not minimal.")

    return is_actual_cause

In [33]:
def minimal_sets(set_list):
    """Return the inclusion-minimal members of ``set_list``, in input order.

    A member is dropped when some *other* (unequal) member is a subset of it.
    Equal copies of a minimal set are therefore all kept.
    """
    def _has_smaller_member(candidate):
        return any(
            other != candidate and other.issubset(candidate) for other in set_list
        )

    return [candidate for candidate in set_list if not _has_smaller_member(candidate)]


def tensorize_dictionary(dictionary):
    """Return a new dict with the same keys, each value passed through ``torch.as_tensor``."""
    tensorized = {}
    for key, value in dictionary.items():
        tensorized[key] = torch.as_tensor(value)
    return tensorized

def boolean_constraints_from_list(list):
    """Map every name in ``list`` to ``pyro.distributions.constraints.boolean``.

    The parameter shadows the builtin ``list``; its name is kept so keyword
    callers keep working.
    """
    boolean = pyro.distributions.constraints.boolean
    return dict.fromkeys(list, boolean)


In [34]:

def sufficient_causality_checkA(
    table, considered_antecedent_setting, causal_candidates, *, eps: float = -1e8
):
    """Find minimal cause sets in ``table`` and check the antecedent setting against them.

    Steps:

    1. Keep rows whose ``<name>_obs`` equals the considered setting for every
       name in ``considered_antecedent_setting``.
    2. Keep rows with ``sum_log_prob > eps``.
    3. For each remaining row, collect the causal candidates whose ``_int``
       value differs from their ``_obs`` value.
    4. Reduce these to inclusion-minimal sets with ``minimal_sets``. Minimality
       is checked by set inclusion because an inclusion-minimal set need not
       have the best log-prob sum merely because it has more elements.

    :param table: A frame as produced by ``get_table``.
    :param considered_antecedent_setting: Mapping of antecedent name to the
        observed value to condition on.
    :param causal_candidates: Names whose ``_obs``/``_int`` columns are compared.
    :param eps: Rows with ``sum_log_prob`` at or below this are discarded.
    :returns: ``(unique_actual_cause_sets, sufficiency_flag)``.
        ``unique_actual_cause_sets`` lists the distinct minimal sets in
        first-seen order. ``sufficiency_flag`` tells whether any name in
        ``considered_antecedent_setting`` belongs to one of them. An empty
        filtered table yields ``([], False)``.
    """
    sufficiency_table = table
    for antecedent, setting in considered_antecedent_setting.items():
        sufficiency_table = sufficiency_table[
            sufficiency_table[f"{antecedent}_obs"] == setting
        ]
    sufficiency_table = sufficiency_table[sufficiency_table["sum_log_prob"] > eps]

    candidate_sets = [
        {node for node in causal_candidates if row[f"{node}_int"] != row[f"{node}_obs"]}
        for _, row in sufficiency_table.iterrows()
    ]

    # Computed once after collecting all rows (previously inside the loop,
    # which also left these names unbound when no row survived filtering).
    actual_cause_sets = minimal_sets(candidate_sets)
    # Order-preserving de-duplication keeps the output deterministic.
    unique_actual_cause_sets = [
        set(s) for s in dict.fromkeys(frozenset(s) for s in actual_cause_sets)
    ]

    sufficiency_flag = any(
        key in ac_set
        for key in considered_antecedent_setting.keys()
        for ac_set in actual_cause_sets
    )

    return unique_actual_cause_sets, sufficiency_flag

# Examples

## Forest fire example

In [35]:
def ff_conjunctive():
    """Conjunctive forest-fire model: fire only when a match is dropped AND lightning strikes.

    Both exogenous variables are independent ``Bernoulli(0.5)`` draws, copied
    into deterministic endogenous sites. The value passed to the
    ``forest_fire`` site is boolean; the returned ``forest_fire`` is cast to
    float.
    """
    u_match = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_light = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped", u_match, event_dim=0)
    lightning = pyro.deterministic("lightning", u_light, event_dim=0)

    both_causes = match_dropped.bool() & lightning.bool()
    forest_fire = pyro.deterministic("forest_fire", both_causes, event_dim=0).float()

    return {
        "match_dropped": match_dropped,
        "lightning": lightning,
        "forest_fire": forest_fire,
    }
    
def ff_disjunctive():
    """Disjunctive forest-fire model: fire when a match is dropped OR lightning strikes.

    Both exogenous variables are independent ``Bernoulli(0.5)`` draws, passed
    to the endogenous deterministic sites as booleans. The returned values are
    booleans too (``forest_fire`` is re-cast with ``.bool()`` after its site).
    """
    u_match = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_light = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped", u_match.bool(), event_dim=0)
    lightning = pyro.deterministic("lightning", u_light.bool(), event_dim=0)

    either_cause = match_dropped.bool() | lightning.bool()
    forest_fire = pyro.deterministic("forest_fire", either_cause, event_dim=0).bool()

    return {
        "match_dropped": match_dropped,
        "lightning": lightning,
        "forest_fire": forest_fire,
    }


# Draw one sample from the conjunctive model (no handlers active) as a quick sanity check.
ff_conjunctive()

{'match_dropped': tensor(0.),
 'lightning': tensor(0.),
 'forest_fire': tensor(0.)}

In [36]:
# Example 7.1.2 from the book (presumably Halpern, "Actual Causality" -- TODO confirm).
#
# Every context is available: the agent's knowledge of the world excludes no
# settings. In the conjunctive model, the two causes jointly explain the fire.

all_nodes = ["match_dropped", "lightning", "forest_fire"]

# The explanation candidate (both causes set to 1.0) and the observed consequent.
antecedents = {"match_dropped": 1.0, "lightning": 1.0}
consequents = ["forest_fire"]
consequents_observed = tensorize_dictionary({"forest_fire": 1.0})

# Every node that is not a consequent may take part in a cause.
causal_candidates = [node for node in all_nodes if node not in consequents]
causal_candidate_constraints = boolean_constraints_from_list(causal_candidates)


In [37]:
# Condition e1a: among sampled contexts where the fire occurs and the antecedent
# setting is observed, is part of that setting in a minimal cause set?

# Seed once so this and every later stochastic cell reproduce under Restart & Run All.
pyro.set_rng_seed(0)

with MultiWorldCounterfactual() as explain_mwc:
    with ExplainCauses(
        antecedents=causal_candidate_constraints,
        witnesses=causal_candidates,
        consequents=consequents,
    ):
        with condition(data={"forest_fire": torch.tensor(True)}):
            with pyro.plate("sample", 100):
                with pyro.poutine.trace() as explain_trace:
                    ff_conjunctive()

ff_conjunctive_table = get_table(
    explain_trace, explain_mwc, causal_candidates, causal_candidates, consequents
)

# sufficient_causality_checkA returns (cause_sets, flag). Keep the flag as the
# condition: the tuple itself is always truthy, so `condition_e1a and ...`
# would otherwise never depend on the check.
e1a_cause_sets, condition_e1a = sufficient_causality_checkA(
    ff_conjunctive_table, antecedents, causal_candidates
)

In [41]:
# Condition e1b: P(consequents | do(antecedents)) = 1, i.e. every consequent
# holds in every sampled intervened world.

with MultiWorldCounterfactual() as intervention_mwc:
    with do(actions=antecedents):
        with pyro.plate("samples", 100):
            with pyro.poutine.trace() as intervention_trace:
                ff_conjunctive()

# Read each consequent in the intervened world (index 1 of each antecedent split).
with intervention_mwc:
    outcome_df = pd.DataFrame(
        {
            consequent: gather_intervened(
                intervention_trace.trace.nodes[consequent]["value"], antecedents, []
            ).squeeze().tolist()
            for consequent in consequents
        }
    )

condition_e1b = outcome_df.eq(True).all().all()

print(condition_e1a and condition_e1b)


True


In [45]:

# condition_e3: P(antecedents & consequents) > 0 -- some prior sample matches
#   every antecedent setting together with the observed consequents.
# condition_e4: P(antecedents) < 1 -- some prior sample violates at least one
#   antecedent setting.

with pyro.plate("samples", 100):
    with pyro.poutine.trace() as prior_trace:
        ff_conjunctive()

# Required value of every node as a plain float (the consequent targets are
# tensors), so it compares elementwise against the traced columns.
reqs = {
    name: float(target)
    for name, target in {**antecedents, **consequents_observed}.items()
}

reqs_outcome = pd.DataFrame(
    {req: prior_trace.trace.nodes[req]["value"].squeeze().tolist() for req in reqs}
)

matches_target = reqs_outcome.eq(pd.Series(reqs), axis="columns")

condition_e3 = matches_target.all(axis=1).any()
# Antecedent columns are selected by name, not by position.
condition_e4 = (~matches_target[list(antecedents)]).any(axis=1).any()

print(condition_e3)
print(condition_e4)
     

True
True
