# Actual Causality: the modified Halpern-Pearl definition

**Summary**

Here we show how the tools made available within ChiRho can be used to implement the notion of actual causality developed by Halpern and Pearl (see J. Halpern, *Actual Causality*, 2016), and illustrate how it works by replicating a few key examples from the book.

**Outline**

[Intuitions](#intuitions)
    
[Formalization](#formalization)

- [Structural causal models](#structural-causal-models)

- [Halpern-Pearl modified definition of actual causality](#halpern-pearl-modified-definition-of-actual-causality)

[Implementation](#implementation)

[Examples](#examples)

- [Comments on example selection](#comments-on-example-selection)
  
- [Stone-throwing](#stone-throwing)

- [Forest fire](#forest-fire)

- [Doctors](#doctors)

- [Friendly fire](#friendly-fire)





## Intuitions

Actual causality (sometimes called **token causality** or **specific causality**) is usually contrasted with type causality (sometimes called **general causality**). While the latter is concerned with general statements (such as "smoking causes cancer"), actual causality focuses on particular events. For illustration, consider the following causality-related questions:

- **Friendly Fire**: On March 24, 2002, a B-52 bomber fired a Joint Direct Attack Munition at a US battalion command post, killing three and injuring twenty special forces soldiers. Out of multiple potential contributing factors, which were **actually** responsible for the incident?
  
- **Schizophrenia**: The disease arises from the interaction between multiple genetic and environmental factors. Given a particular patient and what we know about them, which of these factors **actually** caused their condition?
  
- **Explainable AI**: Your loan application has been refused. The bank representative informs you the decision was made using predictive modeling to estimate the probability of default. They give you a list of various factors considered in the prediction. But which of these factors **actually** resulted in the rejection, and what were their contributions?
  
These are questions about **actual causality**. Answers to such questions are not directly useful for prediction tasks, but they help us understand how to prevent undesirable outcomes similar to ones we have observed, or how to promote desirable outcomes in contexts similar to those in which they were observed. These context-sensitive causality questions are also an essential element of blame and responsibility assignments, and of at least one prominent account of the notion of explanation (all of which will be explored in other notebooks).

The general intuition behind the notion of actual causality that we will focus on is the following: a certain state of the antecedent nodes is a cause of a given state of the consequent nodes if there is a part of actual reality such that, if we kept that part fixed at what it actually is and intervened on the antecedent nodes to put them in a different state, the consequent nodes would no longer be in their observed states. A proper explication of this notion requires the context of structural causal models. We first explain what these are, and then move on to the definition.

## Formalization 


### Structural causal models

While statistical information might help address questions of actual causality, it is not sufficient. One also requires causal theories that explain how the relevant aspects of the world function, as well as information about the actual facts pertaining to the specific case. For this reason, the notion on which we focus in this notebook is formulated within the framework of structural causal models, which can represent such information.

The notion is defined in the context of a deterministic structural causal model (SCM). One major component of an SCM is a selection of **variables**. For instance, in a very simple model of a forest-fire problem, we might consider three endogenous binary variables, $FF$ (forest fire), $L$ (lightning), and $MD$ (match dropped), whose values are determined by the values of other variables, and two exogenous noise variables $U_{MD}$ and $U_L$ that determine the values of $MD$ and $L$. Some of these variables (nodes) are connected by directed **edges**. In the example at hand, the model contains two edges that go from $U_{MD}$ to $MD$ and from $U_L$ to $L$ respectively, and two edges that go from $L$ to $FF$ and from $MD$ to $FF$. Each endogenous variable is associated with a **structural equation**; for instance, $FF = \max(L, MD)$ indicates that a forest fire occurs if either of the two factors occurs. SCMs also come with a **context**, which is an assignment of values to the **exogenous variables**, whose values are not determined by the structural equations but by factors outside the model. In our example, one context might be that both a match has been dropped and lightning has struck.
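The forest-fire model just described can be written down without any library at all. The sketch below is only an illustration: the names `EQUATIONS` and `solve` are hypothetical and are not part of the implementation used later in this notebook.

```python
# Hypothetical, library-free encoding of the forest-fire SCM described above.
# Each endogenous variable maps to a structural equation that reads the
# values computed so far (exogenous values included).
EQUATIONS = {
    "MD": lambda v: v["U_MD"],             # match dropped copies its noise variable
    "L": lambda v: v["U_L"],               # lightning copies its noise variable
    "FF": lambda v: max(v["L"], v["MD"]),  # fire if either factor occurs
}


def solve(context):
    """Evaluate all endogenous variables given a context (exogenous values)."""
    values = dict(context)
    # The dict is written in topological order, so one pass suffices.
    for name, equation in EQUATIONS.items():
        values[name] = equation(values)
    return values


both = solve({"U_MD": 1, "U_L": 1})     # match dropped and lightning struck
neither = solve({"U_MD": 0, "U_L": 0})  # neither factor occurred
print(both["FF"], neither["FF"])
```

Fixing the context makes the model fully deterministic, which is what the definition of actual causality relies on.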

More formally, a causal model $M$ is a tuple $\langle S, F\rangle$, where:

- $S$ is a **signature**, that is, a tuple $\langle U, V, R\rangle$, where $U$ is a set of exogenous variables, $V$ is a set of endogenous variables, and $R$ assigns to each variable $Y \in U \cup V$ a non-empty range $R(Y)$ of its possible values.

- To each endogenous $X\in V$, $F$ assigns a function $F_X$, which maps the cross-product of ranges of all variables other than $X$ to $R(X)$. In other words, $F_X$ determines the value of $X$ given the values of other variables in the model (some of them might be redundant in a given equation). The intuition is that these functions correspond to structural equations of the form $X = F_X(U, V)$ which are to be read from right to left: if the values of $U\cup V$ are fixed to be such-and-such, say $\vec{u}$ and $\vec{v}$, this causes $X$ to take the value $F_X(\vec{u}, \vec{v})$.

A **deterministic causal model** (also called a **causal setting**) $\langle M, \vec{u}\rangle$ is a causal model $M$ together with fixed settings $\vec{u}$ of its exogenous variables $U$. To intervene, say, to make $Y$ have value $y$, is to replace the structural equation for $Y$ of the form $Y = F_Y(U, V)$ with $Y = y$. $\langle M, \vec{u}\rangle \models [Y \leftarrow y](X = x)$ means: in the deterministic model obtained from $\langle M, \vec{u}\rangle$ by intervening on $Y$ to have value $y$, $X$ has value $x$. Sometimes, instead of $X = x$, one might be interested in a more general claim $\varphi$ involving potentially multiple variables, in which case the notation is $\langle M, \vec{u}\rangle \models [Y \leftarrow y](\varphi)$.
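To make the intervention operation concrete, here is a minimal sketch in plain Python, assuming the disjunctive forest-fire equations from above. The helpers `evaluate` and `intervene` are hypothetical names chosen for this illustration; they are unrelated to the `intervene` operation imported from the library later on.

```python
def evaluate(equations, context):
    """Solve a deterministic model: equations are listed in topological order."""
    values = dict(context)
    for name, equation in equations.items():
        values[name] = equation(values)
    return values


def intervene(equations, assignments):
    """Return the intervened model M': each intervened variable's structural
    equation is replaced by a constant, all other equations are untouched."""
    replaced = dict(equations)
    for name, value in assignments.items():
        replaced[name] = lambda v, value=value: value
    return replaced


forest_fire = {
    "MD": lambda v: v["U_MD"],
    "L": lambda v: v["U_L"],
    "FF": lambda v: max(v["L"], v["MD"]),
}
context = {"U_MD": 1, "U_L": 1}

actual = evaluate(forest_fire, context)
# Does <M, u> |= [MD <- 0](FF = 1) hold? Yes: lightning alone still starts the fire.
no_match = evaluate(intervene(forest_fire, {"MD": 0}), context)
print(actual["FF"], no_match["MD"], no_match["FF"])
```

Returning a new dictionary rather than mutating the original mirrors the definition: the intervened model $M'$ is a different model, and the original one remains available for comparison.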

### Halpern-Pearl modified definition of actual causality

It is important to recognize that the straightforward counterfactual strategy, which asks whether the event would have occurred if the antecedent had not taken place, is inadequate as a definition of actual causality. A simple example can help illustrate this point. Suppose I throw a stone, which hits and shatters a bottle. However, just a second later, Bill also throws a stone at the bottle but misses, solely because the bottle was already shattered by my stone. In this scenario, the intuition is that my throw is the cause of the bottle shattering, even though the bottle would still have shattered if I hadn't thrown the stone. 
This highlights the need for a more elaborate account that considers the actual state, taking into consideration the fact that Bill's stone did not, in fact, hit the bottle. One such account involves the following definition of actual causality:

Given an SCM $M$ and a vector of its exogenous variable settings $\vec{u}$, we write $(M, \vec{u})\models [ \vec{Y} \leftarrow \vec{y}]\psi$ just in case $\psi$ holds in $(M',\vec{u})$, where $M'$ is the intervened model obtained by replacing the structural equation for each $Y_i$ in $\vec{Y}$ with $Y_i = y_i$.

We say that $\vec{X}=\vec{x}$ is an actual cause of $\varphi$ in $(M,\vec{u})$ just in case:

AC1. Factivity: $(M, \vec{u}) \models (\vec{X} = \vec{x}) \wedge \varphi$

AC2. Necessity:

$\exists \vec{W}, \vec{x}' \; (M, \vec{u})\models [\vec{X} \leftarrow \vec{x}', \vec{W} \leftarrow \vec{w}^{\star}] \neg \varphi$,
where $\vec{w}^\star$ are the actual values of $\vec{W}$, i.e. $(M, \vec{u}) \models \vec{W} = \vec{w}^\star$

AC3. Minimality: $\vec{X}$ is a subset-minimal set of potential causes satisfying AC2

AC1 requires that both the antecedent and the consequent hold. The intuition behind AC2 is that for $\vec{X}=\vec{x}$ to be an actual cause of $\varphi$, there needs to be a vector of witness nodes $\vec{W}$ and a vector $\vec{x}'$ of *alternative* settings of $\vec{X}$ such that if $\vec{W}$ are intervened on to have their actual values $\vec{w}^\star$, and $\vec{X}$ are intervened on to have values $\vec{x}'$, $\varphi$ no longer holds in the resulting model. AC3 requires that no strict subset of the antecedent satisfies AC2.
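For small binary models, the three conditions can be checked by brute-force enumeration. The following sketch is a library-free illustration under stated assumptions: all variables are binary, the equations are listed in topological order, and the names `evaluate`, `satisfies_ac2`, `is_actual_cause` and the short variable labels (`ST`, `BT`, `SH`, `BH`, `BS`) are hypothetical. It is not the sampling-based implementation used in the rest of this notebook.

```python
from itertools import combinations, product


def evaluate(equations, context, interventions=None):
    """Solve the model, overriding intervened variables with fixed values."""
    interventions = interventions or {}
    values = dict(context)
    for name, equation in equations.items():
        values[name] = interventions[name] if name in interventions else equation(values)
    return values


def satisfies_ac2(equations, context, cause_vars, outcome):
    """AC2: some alternative setting of cause_vars, together with some witness
    set frozen at its actual values, makes the outcome fail."""
    actual = evaluate(equations, context)
    others = [v for v in equations if v not in cause_vars and v not in outcome]
    for size in range(len(others) + 1):
        for witnesses in combinations(others, size):
            frozen = {w: actual[w] for w in witnesses}
            for alternative in product([0, 1], repeat=len(cause_vars)):
                setting = dict(zip(cause_vars, alternative))
                world = evaluate(equations, context, {**frozen, **setting})
                if any(world[k] != val for k, val in outcome.items()):
                    return True
    return False


def is_actual_cause(equations, context, cause, outcome):
    actual = evaluate(equations, context)
    # AC1: the cause and the outcome both hold in the actual world.
    if any(actual[k] != val for k, val in {**cause, **outcome}.items()):
        return False
    cause_vars = list(cause)
    if not satisfies_ac2(equations, context, cause_vars, outcome):
        return False
    # AC3: no non-empty strict subset of the cause satisfies AC2.
    for size in range(1, len(cause_vars)):
        for subset in combinations(cause_vars, size):
            if satisfies_ac2(equations, context, list(subset), outcome):
                return False
    return True


# Stone throwing: Bill hits only if he throws and Sally's stone did not hit first.
stones = {
    "ST": lambda v: v["U_ST"],
    "BT": lambda v: v["U_BT"],
    "SH": lambda v: v["ST"],
    "BH": lambda v: int(v["BT"] and not v["SH"]),
    "BS": lambda v: int(v["SH"] or v["BH"]),
}
context = {"U_ST": 1, "U_BT": 1}

sally_is_cause = is_actual_cause(stones, context, {"ST": 1}, {"BS": 1})
bill_is_cause = is_actual_cause(stones, context, {"BT": 1}, {"BS": 1})
pair_is_cause = is_actual_cause(stones, context, {"ST": 1, "BT": 1}, {"BS": 1})
print(sally_is_cause, bill_is_cause, pair_is_cause)
```

Sally's throw passes AC2 only when `BH` is used as a witness and frozen at its actual value of 0, which is exactly the role the witness set plays in the definition. The pair of throws passes AC2 but fails AC3, because Sally's throw alone already satisfies AC2.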

## Implementation


In [2]:
import collections.abc  # ExplainCauses below relies on collections.abc.Mapping
import contextlib
import functools
from typing import Callable, Dict, Iterable, List, Mapping, TypeVar

import pandas as pd
import pyro
import pyro.distributions as dist
import pyro.infer
import torch

import chirho
from chirho.counterfactual.handlers.counterfactual import (
    MultiWorldCounterfactual,
    Preemptions,
)
from chirho.counterfactual.handlers.explanation import (
    SearchForCause,
    consequent_differs,
    random_intervention,
    undo_split,
    uniform_proposal,
)
from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, cond, gather, indices_of, scatter
from chirho.interventional.handlers import do
from chirho.interventional.ops import Intervention, intervene
from chirho.observational.handlers.condition import Factors, condition

S = TypeVar("S")
T = TypeVar("T")

In [3]:
@contextlib.contextmanager
def ExplainCauses(
    antecedents: Mapping[str, Intervention[T]]
    | Mapping[str, pyro.distributions.constraints.Constraint],
    witnesses: Mapping[str, Intervention[T]] | Iterable[str],
    consequents: Mapping[str, Callable[[T], float | torch.Tensor]]
    | Iterable[str],
    *,
    antecedent_bias: float = 0.0,
    witness_bias: float = 0.0,
    consequent_eps: float = -1e8,
    antecedent_prefix: str = "__antecedent_",
    witness_prefix: str = "__witness_",
    consequent_prefix: str = "__consequent_",
):
    """
    Effect handler for causal explanation.

    :param antecedents: A mapping from antecedent names to interventions.
    :param witnesses: A mapping from witness names to interventions.
    :param consequents: A mapping from consequent names to factor functions.
    """
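    # Validate before peeking at the first antecedent: calling next() on an
    # empty mapping would raise StopIteration instead of a helpful error.
    if len(antecedents) == 0:
        raise ValueError("must have at least one antecedent")

    # Constraints are turned into random interventions over their support.
    if isinstance(
        next(iter(antecedents.values())),
        pyro.distributions.constraints.Constraint,
    ):
        antecedents = {
            a: random_intervention(s, name=f"{antecedent_prefix}_proposal_{a}")
            for a, s in antecedents.items()
        }

    # Bare witness names default to preempting the node to its factual value.
    if not isinstance(witnesses, collections.abc.Mapping):
        witnesses = {
            w: undo_split(antecedents=list(antecedents.keys()))
            for w in witnesses
        }

    # Bare consequent names default to checking whether the consequent differs.
    if not isinstance(consequents, collections.abc.Mapping):
        consequents = {
            c: consequent_differs(
                antecedents=list(antecedents.keys()), eps=consequent_eps
            )
            for c in consequents
        }

    if len(consequents) == 0:
        raise ValueError("must have at least one consequent")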

    if set(consequents.keys()) & set(antecedents.keys()):
        raise ValueError(
            "consequents and possible antecedents must be disjoint"
        )

    if set(consequents.keys()) & set(witnesses.keys()):
        raise ValueError("consequents and possible witnesses must be disjoint")

    antecedent_handler = SearchForCause(
        actions=antecedents, bias=antecedent_bias, prefix=antecedent_prefix
    )
    witness_handler = Preemptions(
        actions=witnesses, bias=witness_bias, prefix=witness_prefix
    )
    consequent_handler = Factors(factors=consequents, prefix=consequent_prefix)

    with antecedent_handler, witness_handler, consequent_handler:
        with pyro.poutine.trace() as logging_tr:
            yield

Here and in later notebooks, instead of full enumeration, we approximate the answers by sampling. Answering an actual causality query requires investigating the consequences of intervening on every possible combination of candidate witness nodes to have the values they actually have in a given model. While complete enumeration would work for smaller models, we implement a more general approximate method, which repeatedly draws random sets of witness nodes and intervenes on those sampled sets. For smaller models (such as the ones used in our examples), complete coverage of all possible combinations is easily obtained. For larger models, complete enumeration becomes less feasible.
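The coverage claim can be illustrated with a small, library-free simulation. This is only a sketch: `sample_witness_sets` is a hypothetical helper that mimics independent per-node inclusion with probability `0.5 + witness_bias`, and it does not call the handlers defined above.

```python
import random
from itertools import combinations


def sample_witness_sets(candidates, n_samples, witness_bias=0.0, seed=0):
    """Draw random witness sets: each candidate is included independently."""
    rng = random.Random(seed)
    p = 0.5 + witness_bias
    return [
        frozenset(c for c in candidates if rng.random() < p)
        for _ in range(n_samples)
    ]


candidates = ["bill_throws", "bill_hits"]
drawn = set(sample_witness_sets(candidates, n_samples=200))

# All 2^n subsets of the candidate witnesses, for comparison.
all_subsets = {
    frozenset(s)
    for size in range(len(candidates) + 1)
    for s in combinations(candidates, size)
}
coverage = len(drawn & all_subsets) / len(all_subsets)
print(f"{len(all_subsets)} possible witness sets, coverage = {coverage:.0%}")

# The number of subsets doubles with every additional candidate witness.
print(2 ** 20)
```

With two candidates and 200 draws, missing any of the four subsets is astronomically unlikely. With twenty candidates there are over a million subsets, so a sample of this size can only cover a small fraction of them.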

An SCM in this context is represented by a Pyro model, where the exogenous variables are stochastic and introduced using `pyro.sample`, and all the endogenous variables are determined by these, and introduced by `pyro.deterministic` (read on for examples). For simplicity we often assume most of the nodes are binary (this assumption can be weakened, read on for details), and that the nodes are discrete. 

The key role in this implementation is played by the `ExplainCauses` handler defined above, which composes the `SearchForCause`, `Preemptions` and `Factors` handlers. It takes `antecedents`, `witnesses`, `consequents`, `antecedent_bias` and `witness_bias`, and roughly takes three steps:

(A) It randomly intervenes on some of the antecedents (each antecedent node has probability `0.5 - antecedent_bias` of being intervened on; a positive bias expresses a preference for smaller antecedent sets) to have an alternative value (either pre-specified or randomly selected).

(B) It randomly preempts some of the witnesses, intervening on them to have their observed value in all counterfactual worlds (the probability of witness preemption is `0.5 + witness_bias`).

(C) It adds a site with `log_probs` tracking whether the counterfactual value of any of the consequents differs from its observed value, marking cases where it doesn't with an extremely low `log_prob` of `-1e8` (and `0` otherwise).

Since those steps are achieved by adding new sites to the model, the model trace can now be inspected to test for actual causality. In particular, if the `log_prob` of the site added in (C) is `-1e8` in every sampled world, then no sampled intervention changed the consequent, and the antecedent is not an actual cause of the consequent (as far as the sample can tell). If it is `0` in some world, minimality claims are evaluated by investigating the `log_prob_sum` corresponding to the antecedent preemption sites. All in all, an antecedent set is an actual cause if all its nodes and only its nodes are intervened on in the MAP counterfactual world.
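The scoring logic can be reproduced by hand. The sketch below is a plain-Python illustration, assuming the probabilities stated in (A) and the penalty stated in (C); `world_log_prob` is a hypothetical helper, not a function of the library.

```python
import math


def world_log_prob(intervened, consequent_differs, antecedent_bias=0.0, eps=-1e8):
    """Score one sampled world.

    intervened: dict mapping antecedent name -> True if it was intervened on.
    consequent_differs: True if the consequent changed in this world.
    """
    total = 0.0
    for was_intervened in intervened.values():
        p = 0.5 - antecedent_bias if was_intervened else 0.5 + antecedent_bias
        total += math.log(p)
    # No penalty if the consequent changed; a huge penalty otherwise.
    total += 0.0 if consequent_differs else eps
    return total


bias = 0.1
# Candidate antecedent set {sally_throws, bill_throws}, consequent bottle_shatters.
only_sally = world_log_prob({"sally_throws": True, "bill_throws": False}, True, bias)
both = world_log_prob({"sally_throws": True, "bill_throws": True}, True, bias)
nothing = world_log_prob({"sally_throws": False, "bill_throws": False}, False, bias)

print(round(only_sally, 6), round(both, 6), nothing <= -1e8)
```

The world that intervenes on Sally's throw alone scores higher than the world that intervenes on both throws, so the MAP world does not use the whole candidate set. This is how a positive bias turns the minimality condition into a comparison of log-probabilities.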

In [4]:
def gather_observed(value, antecedents, witnesses):
    # Select the factual world (index 0) along every antecedent/witness
    # index that is present in `value`.
    _indices = [
        i
        for i in list(antecedents.keys()) + witnesses
        if i in indices_of(value, event_dim=0)
    ]
    _int_can = gather(
        value,
        IndexSet(**{i: {0} for i in _indices}),
        event_dim=0,
    )
    return _int_can


def gather_intervened(value, antecedents, witnesses):
    # Select the counterfactual world (index 1) along every antecedent/witness
    # index that is present in `value`.
    _indices = [
        i
        for i in list(antecedents.keys()) + witnesses
        if i in indices_of(value, event_dim=0)
    ]
    _int_can = gather(
        value,
        IndexSet(**{i: {1} for i in _indices}),
        event_dim=0,
    )
    return _int_can


def get_table(trace, mwc, antecedents, witnesses, consequents):

    values_table = {}
    nodes = trace.trace.nodes

    with mwc:

        for antecedent_str in antecedents.keys():
                
            obs_ant = gather_observed(nodes[antecedent_str]["value"], antecedents, witnesses)
            int_ant = gather_intervened(nodes[antecedent_str]["value"], antecedents, witnesses)

            values_table[f"{antecedent_str}_obs"] = obs_ant.squeeze().tolist()
            values_table[f"{antecedent_str}_int"] = int_ant.squeeze().tolist()
            
            apr_ant = nodes[f"__antecedent_{antecedent_str}"]["value"]
            values_table[f"apr_{antecedent_str}"] = apr_ant.squeeze().tolist()
            
            values_table[f"apr_{antecedent_str}_lp"] = nodes[f"__antecedent_{antecedent_str}"]["fn"].log_prob(apr_ant)

        if witnesses:
            for candidate in witnesses:
                obs_candidate = gather_observed(nodes[candidate]["value"], antecedents, witnesses)
                int_candidate = gather_intervened(nodes[candidate]["value"], antecedents, witnesses)
                values_table[f"{candidate}_obs"] = obs_candidate.squeeze().tolist()
                values_table[f"{candidate}_int"] = int_candidate.squeeze().tolist()

                wpr_con = nodes[f"__witness_{candidate}"]["value"]
                values_table[f"wpr_{candidate}"] = wpr_con.squeeze().tolist()
            

        for consequent in consequents:
            obs_consequent = gather_observed(nodes[consequent]["value"], antecedents, witnesses)
            int_consequent = gather_intervened(nodes[consequent]["value"], antecedents, witnesses)
            con_lp = nodes[f"__consequent_{consequent}"]['fn'].log_prob(torch.tensor(1)) #TODO: this feels like a hack
            _indices_lp = [
            i for i in list(antecedents.keys()) + witnesses if i in indices_of(con_lp)]
            int_con_lp = gather(con_lp, IndexSet(**{i: {1} for i in _indices_lp}), event_dim=0,)      


            values_table[f"{consequent}_obs"] = obs_consequent.squeeze().tolist()   
            values_table[f"{consequent}_int"] = int_consequent.squeeze().tolist()
            values_table[f"{consequent}_lp"] = int_con_lp.squeeze().tolist()   

    values_df = pd.DataFrame(values_table)

    values_df.drop_duplicates(inplace=True)

    summands = [col for col in values_df.columns if col.endswith('lp')]
    values_df["sum_log_prob"] =  values_df[summands].sum(axis = 1)
    values_df.sort_values(by = "sum_log_prob", ascending = False, inplace = True)

    return values_df

In [5]:
# this reduces the actual causality check to checking a property of the resulting sums of log probabilities
# for the antecedent preemption and the consequent differs nodes

def ac_check(trace, mwc, antecedents, witnesses, consequents):

    table = get_table(trace, mwc, antecedents, witnesses, consequents)

    # Rows are sorted by sum_log_prob (descending), so the first row is the best world.
    if table["sum_log_prob"].iloc[0] <= -1e8:
        print("No resulting difference to the consequent in the sample.")
        return

    winners = table[table["sum_log_prob"] == table["sum_log_prob"].max()]

    ac_flags = []
    for index, row in winners.iterrows():
        # apr_<name> == 0 marks an antecedent that was actually intervened on.
        active_antecedents = [
            antecedent
            for antecedent in antecedents
            if row[f"apr_{antecedent}"] == 0
        ]
        ac_flags.append(set(active_antecedents) == set(antecedents))

    if not any(ac_flags):
        print("The antecedent set is not minimal.")
    else:
        print("The antecedent set is an actual cause.")

    return any(ac_flags)

## Examples

### Stone-throwing

In [5]:
@pyro.infer.config_enumerate
def stones_model():        
    prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1))
    prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1))
    prob_sally_hits = pyro.sample("prob_sally_hits", dist.Beta(1, 1))
    prob_bill_hits = pyro.sample("prob_bill_hits", dist.Beta(1, 1))
    prob_bottle_shatters_if_sally = pyro.sample("prob_bottle_shatters_if_sally", dist.Beta(1, 1))
    prob_bottle_shatters_if_bill = pyro.sample("prob_bottle_shatters_if_bill", dist.Beta(1, 1))

    sally_throws = pyro.sample("sally_throws", dist.Bernoulli(prob_sally_throws))
    bill_throws = pyro.sample("bill_throws", dist.Bernoulli(prob_bill_throws))


    # Sally can only hit if she throws.
    new_shp = torch.where(sally_throws == 1, prob_sally_hits, torch.tensor(0.0))

    sally_hits = pyro.sample("sally_hits", dist.Bernoulli(new_shp))

    # Bill can only hit if he throws and Sally's stone has not hit first.
    new_bhp = torch.where(
        bill_throws.bool() & (~sally_hits.bool()),
        prob_bill_hits,
        torch.tensor(0.0),
    )

    bill_hits = pyro.sample("bill_hits", dist.Bernoulli(new_bhp))

    # The bottle can shatter only if at least one of the stones hits.
    new_bsp = torch.where(
        bill_hits.bool(),
        prob_bottle_shatters_if_bill,
        torch.where(
            sally_hits.bool(),
            prob_bottle_shatters_if_sally,
            torch.tensor(0.0),
        ),
    )

    bottle_shatters = pyro.sample("bottle_shatters", dist.Bernoulli(new_bsp))

    return {"sally_throws": sally_throws, "bill_throws": bill_throws,  "sally_hits": sally_hits,
            "bill_hits": bill_hits,  "bottle_shatters": bottle_shatters,}

stones_model.nodes = ["sally_throws","bill_throws", "sally_hits", "bill_hits","bottle_shatters",]

In [6]:
def tensorize_observations(observations):
    return {k: torch.as_tensor(v) for k, v in observations.items()}

observations = {"prob_sally_throws": 1.0, 
                "prob_bill_throws": 1.0,
                "prob_sally_hits": 1.0,
                "prob_bill_hits": 1.0,
                "prob_bottle_shatters_if_sally": 1.0,
                "prob_bottle_shatters_if_bill": 1.0,
                "sally_throws": 1.0, "bill_throws": 1.0}

observations_tensorized = tensorize_observations(observations)

# One way to go is to manually specify a single alternative value
# which helps if you explicitly want to use a contrastive notion of 
# actual causality
antecedents = {"sally_throws": 0.0}
antencedent_bias = 0.1
witnesses = ["bill_throws", "bill_hits"]
consequents = ["bottle_shatters"]

#TODO? fails silently when consequents is a string

In [7]:
with MultiWorldCounterfactual() as mwc:
    with ExplainCauses(antecedents = antecedents, 
                       witnesses = witnesses, consequents = consequents):
        with condition(data = observations_tensorized):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr:
                    stones_model()


In [8]:
stones_table = get_table(tr, mwc, antecedents, witnesses, consequents)
display(stones_table)

| index | sally_throws_obs | sally_throws_int | apr_sally_throws | apr_sally_throws_lp | bill_throws_obs | bill_throws_int | wpr_bill_throws | bill_hits_obs | bill_hits_int | wpr_bill_hits | bottle_shatters_obs | bottle_shatters_int | bottle_shatters_lp | sum_log_prob |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2 | 1.0 | 0.0 | 0 | -0.693147 | 1.0 | 1.0 | 0 | 0.0 | 0.0 | 1 | 1.0 | 0.0 | 0.0 | -0.6931472 |
| 14 | 1.0 | 0.0 | 0 | -0.693147 | 1.0 | 1.0 | 1 | 0.0 | 0.0 | 1 | 1.0 | 0.0 | 0.0 | -0.6931472 |
| 0 | 1.0 | 1.0 | 1 | -0.693147 | 1.0 | 1.0 | 1 | 0.0 | 0.0 | 0 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 1 | 1.0 | 0.0 | 0 | -0.693147 | 1.0 | 1.0 | 1 | 0.0 | 1.0 | 0 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 4 | 1.0 | 0.0 | 0 | -0.693147 | 1.0 | 1.0 | 0 | 0.0 | 1.0 | 0 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 5 | 1.0 | 1.0 | 1 | -0.693147 | 1.0 | 1.0 | 0 | 0.0 | 0.0 | 0 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 7 | 1.0 | 1.0 | 1 | -0.693147 | 1.0 | 1.0 | 1 | 0.0 | 0.0 | 1 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 8 | 1.0 | 1.0 | 1 | -0.693147 | 1.0 | 1.0 | 0 | 0.0 | 0.0 | 1 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |


In [9]:
ac_check(tr, mwc, antecedents, witnesses, consequents)

The antecedent set is an actual cause.


True

In [10]:
# If, more in the spirit of the original definition,
# we want to search through all possible values of the antecedent,
# we can use a constraint instead of specifying
# the counterfactual value manually

antecedents = {"sally_hits": pyro.distributions.constraints.boolean}


with MultiWorldCounterfactual() as mwc:
    with ExplainCauses(antecedents = antecedents, 
                       witnesses = witnesses, consequents = consequents):
        with condition(data = observations_tensorized):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr:
                    stones_model()
                    
                

In [13]:
# now our samples include some cases where the antecedent intervention
# was with the observed value; this does not change the result,
# as the __consequent_ log prob is practically -inf in these cases

stones_table = get_table(tr, mwc, antecedents, witnesses, consequents)
display(stones_table)

| index | sally_hits_obs | sally_hits_int | apr_sally_hits | apr_sally_hits_lp | bill_throws_obs | bill_throws_int | wpr_bill_throws | bill_hits_obs | bill_hits_int | wpr_bill_hits | bottle_shatters_obs | bottle_shatters_int | bottle_shatters_lp | sum_log_prob |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 3 | 1.0 | 0.0 | 0 | -0.693147 | 1.0 | 1.0 | 0 | 0.0 | 0.0 | 1 | 1.0 | 0.0 | 0.0 | -0.6931472 |
| 4 | 1.0 | 0.0 | 0 | -0.693147 | 1.0 | 1.0 | 1 | 0.0 | 0.0 | 1 | 1.0 | 0.0 | 0.0 | -0.6931472 |
| 0 | 1.0 | 0.0 | 0 | -0.693147 | 1.0 | 1.0 | 0 | 0.0 | 1.0 | 0 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 1 | 1.0 | 1.0 | 1 | -0.693147 | 1.0 | 1.0 | 0 | 0.0 | 0.0 | 1 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 2 | 1.0 | 1.0 | 1 | -0.693147 | 1.0 | 1.0 | 0 | 0.0 | 0.0 | 0 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 5 | 1.0 | 1.0 | 1 | -0.693147 | 1.0 | 1.0 | 1 | 0.0 | 0.0 | 0 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 6 | 1.0 | 1.0 | 0 | -0.693147 | 1.0 | 1.0 | 1 | 0.0 | 0.0 | 1 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 7 | 1.0 | 1.0 | 0 | -0.693147 | 1.0 | 1.0 | 0 | 0.0 | 0.0 | 1 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 10 | 1.0 | 1.0 | 0 | -0.693147 | 1.0 | 1.0 | 0 | 0.0 | 0.0 | 0 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 13 | 1.0 | 1.0 | 1 | -0.693147 | 1.0 | 1.0 | 1 | 0.0 | 0.0 | 1 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |


In [15]:
ac_check(tr, mwc, antecedents, witnesses, consequents)

# since we're dealing with binary antecedents in this notebook,
# we'll keep using the contrastive notion in what follows

The antecedent set is an actual cause.


True

In [16]:
# this antecedent set is not minimal
antecedents2 = {"sally_throws": 0.0, "bill_throws": 0.0}
witnesses2 = ["bill_hits"]  #we know 'bill_throws' is not downstream from any split

with MultiWorldCounterfactual() as mwc2:
    with ExplainCauses(antecedents = antecedents2, antecedent_bias= antencedent_bias,
                        witnesses = witnesses2,
                        consequents = consequents):
        with condition(data = observations_tensorized):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr2:
                    stones_model()

In [17]:
stones_table2 = get_table(tr2, mwc2, antecedents2, witnesses2, consequents)
display(stones_table2)

| index | sally_throws_obs | sally_throws_int | apr_sally_throws | apr_sally_throws_lp | bill_throws_obs | bill_throws_int | apr_bill_throws | apr_bill_throws_lp | bill_hits_obs | bill_hits_int | wpr_bill_hits | bottle_shatters_obs | bottle_shatters_int | bottle_shatters_lp | sum_log_prob |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 13 | 1.0 | 0.0 | 0 | -0.916291 | 1.0 | 1.0 | 1 | -0.510826 | 0.0 | 0.0 | 1 | 1.0 | 0.0 | 0.0 | -1.427116 |
| 3 | 1.0 | 0.0 | 0 | -0.916291 | 1.0 | 0.0 | 0 | -0.916291 | 0.0 | 0.0 | 1 | 1.0 | 0.0 | 0.0 | -1.832581 |
| 24 | 1.0 | 0.0 | 0 | -0.916291 | 1.0 | 0.0 | 0 | -0.916291 | 0.0 | 0.0 | 0 | 1.0 | 0.0 | 0.0 | -1.832581 |
| 1 | 1.0 | 1.0 | 1 | -0.510826 | 1.0 | 1.0 | 1 | -0.510826 | 0.0 | 0.0 | 1 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 2 | 1.0 | 1.0 | 1 | -0.510826 | 1.0 | 1.0 | 1 | -0.510826 | 0.0 | 0.0 | 0 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 0 | 1.0 | 1.0 | 1 | -0.510826 | 1.0 | 0.0 | 0 | -0.916291 | 0.0 | 0.0 | 0 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 5 | 1.0 | 0.0 | 0 | -0.916291 | 1.0 | 1.0 | 1 | -0.510826 | 0.0 | 1.0 | 0 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |
| 8 | 1.0 | 1.0 | 1 | -0.510826 | 1.0 | 0.0 | 0 | -0.916291 | 0.0 | 0.0 | 1 | 1.0 | 1.0 | -100000000.0 | -100000000.0 |


In [18]:
ac_check(tr2, mwc2, antecedents2, witnesses2, consequents)

The antecedent set is not minimal.


False

### Forest fire

In [13]:
def ff_conjunctive():
    u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped",
                                       u_match_dropped, event_dim=0)
    lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
    forest_fire = pyro.deterministic("forest_fire", match_dropped.bool() & lightning.bool(),
                                     event_dim=0).float()

    return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}

In [14]:
def ff_disjunctive():
    u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped",
                                       u_match_dropped, event_dim=0)
    lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
    forest_fire = pyro.deterministic("forest_fire", match_dropped.bool() | lightning.bool(), event_dim=0).float()

    return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}


In [15]:
antecedents_ff = {"match_dropped": 0.0}
witnesses_ff = ["lightning"]
consequents_ff = ["forest_fire"]
observations_ff = tensorize_observations({"match_dropped": 1.0, "lightning": 1.0})

In [16]:
with MultiWorldCounterfactual() as mwc_ff:
    with ExplainCauses(antecedents = antecedents_ff, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_ff,
                        consequents = consequents_ff):
        with condition(data = observations_ff):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_ff:
                    ff_conjunctive()

In [17]:
# In the conjunctive model,
# each of the two factors is a but-for cause

ac_check(tr_ff, mwc_ff, antecedents_ff, witnesses_ff, consequents_ff)

The antecedent set is an actual cause.


True

In [18]:
# In the disjunctive model,
# there would still be a fire if no match had been dropped,
# since lightning alone suffices

with MultiWorldCounterfactual() as mwc_ffd:
    with ExplainCauses(antecedents = antecedents_ff, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_ff,
                        consequents = consequents_ff):
        with condition(data = observations_ff):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_ffd:
                    ff_disjunctive()

In [19]:
ac_check(tr_ffd, mwc_ffd, antecedents_ff, witnesses_ff, consequents_ff)

No resulting difference to the consequent in the sample.


In [20]:
# In the disjunctive model,
# the actual cause is the conjunction of the two factors

antecedents_ffd2 = {"match_dropped": 0.0, "lightning":0.0}
witnesses_ffd2 = []

with MultiWorldCounterfactual() as mwc_ffd2:
    with ExplainCauses(antecedents = antecedents_ffd2, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_ffd2,
                        consequents = consequents_ff):
        with condition(data = observations_ff):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_ffd2:
                    ff_disjunctive()

In [21]:
ac_check(tr_ffd2, mwc_ffd2, antecedents_ffd2, witnesses_ffd2, consequents_ff)

The antecedent set is an actual cause.


True

### Doctors

In [22]:
def bc_function(mt, tt):
    condition1 = (mt == 1) & (tt == 1)
    condition2 = (mt == 1) & (tt == 0)
    condition3 = (mt == 0) & (tt == 1)
    condition4 = ~(condition1 | condition2 | condition3)

    output = torch.where(condition1, torch.tensor(3.0), torch.tensor(0.0))
    output = torch.where(condition2, torch.tensor(0.0), output)
    output = torch.where(condition3, torch.tensor(1.0), output)
    output = torch.where(condition4, torch.tensor(2.0), output)

    return output


def model_doctors():
    u_monday_treatment = pyro.sample("u_monday_treatment", dist.Bernoulli(0.5))

    monday_treatment = pyro.deterministic(
        "monday_treatment", u_monday_treatment, event_dim=0
    )

    tuesday_treatment = pyro.deterministic(
        "tuesday_treatment",
        torch.logical_not(monday_treatment).float(),
        event_dim=0,
    )

    bills_condition = pyro.deterministic(
        "bills_condition",
        bc_function(monday_treatment, tuesday_treatment),
        event_dim=0,
    )

    bill_alive = pyro.deterministic(
        "bill_alive", bills_condition.not_equal(3.0).float(), event_dim=0
    )

    return {
        "monday_treatment": monday_treatment,
        "tuesday_treatment": tuesday_treatment,
        "bills_condition": bills_condition,
        "bill_alive": bill_alive,
    }

In [23]:
antecedents_doc1 = {"monday_treatment": 0.0}
witnesses_doc = []
consequents_doc1 = ["tuesday_treatment"]
observations_doc = tensorize_observations({"u_monday_treatment": 1.0})

In [24]:
# First link: Monday's treatment is an actual cause of Tuesday's treatment being withheld.

with MultiWorldCounterfactual() as mwc_doc1:
    with ExplainCauses(antecedents = antecedents_doc1, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_doc,
                        consequents = consequents_doc1):
        with condition(data = observations_doc):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_doc1:
                    model_doctors()
                    
ac_check(tr_doc1, mwc_doc1, antecedents_doc1, witnesses_doc, consequents_doc1)

The antecedent set is an actual cause.


True

In [25]:
# Second link: Tuesday's treatment being withheld is an actual cause of Bill being alive.

antecedents_doc2 = {"tuesday_treatment": 1.0}
consequents_doc2 = ["bill_alive"]


with MultiWorldCounterfactual() as mwc_doc2:
    with ExplainCauses(antecedents = antecedents_doc2, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_doc,
                        consequents = consequents_doc2):
        with condition(data = observations_doc):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_doc2:
                    model_doctors()

ac_check(tr_doc2, mwc_doc2, antecedents_doc2, witnesses_doc, consequents_doc2)

The antecedent set is an actual cause.


True

In [26]:
# Monday's treatment is nevertheless not an actual cause of Bill being alive:
# actual causality is not transitive.

with MultiWorldCounterfactual() as mwc_doc3:
    with ExplainCauses(antecedents = antecedents_doc1, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_doc,
                        consequents = consequents_doc2):
        with condition(data = observations_doc):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_doc3:
                    model_doctors()

ac_check(tr_doc3, mwc_doc3, antecedents_doc1, witnesses_doc, consequents_doc2)

No resulting difference to the consequent in the sample.
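
The failure of transitivity can also be checked by hand. The sketch below re-implements the structural equations of `model_doctors` as a plain Python function and evaluates the three counterfactuals directly, with an empty witness set as in the cells above. The function name `doctors` and its `do_tuesday` argument are illustrative assumptions, not part of the library; the check is a simple comparison of the actual and the intervened outcome, not a replacement for `ac_check`.

```python
def doctors(monday_treatment, do_tuesday=None):
    """Evaluate the doctors model for a given Monday treatment.

    `do_tuesday` (hypothetical helper argument) overrides the structural
    equation for Tuesday's treatment, i.e. it models an intervention.
    """
    if do_tuesday is None:
        # Tuesday's doctor treats Bill only if Monday's doctor did not.
        tuesday_treatment = 1 - monday_treatment
    else:
        tuesday_treatment = do_tuesday

    # Same encoding as bc_function: 3 (treated on both days) is the fatal case.
    bills_condition = {(1, 1): 3, (1, 0): 0, (0, 1): 1, (0, 0): 2}[
        (monday_treatment, tuesday_treatment)
    ]
    return {
        "tuesday_treatment": tuesday_treatment,
        "bill_alive": int(bills_condition != 3),
    }


actual = doctors(monday_treatment=1)  # what actually happened
no_monday = doctors(monday_treatment=0)  # intervene on Monday's treatment
forced_tuesday = doctors(monday_treatment=1, do_tuesday=1)  # intervene on Tuesday's

# Link 1: changing Monday's treatment changes Tuesday's treatment.
link1 = no_monday["tuesday_treatment"] != actual["tuesday_treatment"]
# Link 2: changing Tuesday's treatment changes whether Bill is alive.
link2 = forced_tuesday["bill_alive"] != actual["bill_alive"]
# Link 3: changing Monday's treatment does not change whether Bill is alive.
link3 = no_monday["bill_alive"] != actual["bill_alive"]

print(link1, link2, link3)
```

If Monday's treatment is withheld, Tuesday's doctor steps in and Bill still survives, which is why the third dependence fails even though the first two hold.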


### Friendly fire


In [27]:

def model_friendly_fire():
    u_f4_PLGR_now = pyro.sample("u_f4_PLGR_now", dist.Bernoulli(0.5))
    u_f11_training = pyro.sample("u_f11_training", dist.Bernoulli(0.5))

    f4_PLGR_now = pyro.deterministic("f4_PLGR_now", u_f4_PLGR_now, event_dim=0)
    f11_training = pyro.deterministic(
        "f11_training", u_f11_training, event_dim=0
    )

    f6_PLGR_before = pyro.deterministic(
        "f6_PLGR_before", f4_PLGR_now, event_dim=0
    )
    f7_second_calculation = pyro.deterministic(
        "f7_second_calculation", f4_PLGR_now, event_dim=0
    )
    f13_battery_died = pyro.deterministic(
        "f13_battery_died",
        f6_PLGR_before.bool() & f7_second_calculation.bool(),
        event_dim=0,
    )

    f1_battery_change = pyro.deterministic(
        "f1_battery_change", f13_battery_died, event_dim=0
    )

    f12_PLGR_after = pyro.deterministic(
        "f12_PLGR_after", f1_battery_change, event_dim=0
    )

    f5_unaware = pyro.deterministic("f5_unaware", f11_training, event_dim=0)

    f14_wrong_position = pyro.deterministic(
        "f14_wrong_position", f5_unaware, event_dim=0
    )

    f9_mistake_call = pyro.deterministic(
        "f9_mistake_call",
        f12_PLGR_after.bool() & f14_wrong_position.bool(),
        event_dim=0,
    )

    f3_fired = pyro.deterministic("f3_fired", f9_mistake_call, event_dim=0)

    f10_landed = pyro.deterministic(
        "f10_landed", f3_fired.bool() & f9_mistake_call.bool(), event_dim=0
    )

    f2_killed = pyro.deterministic("f2_killed", f10_landed, event_dim=0)

    return {
        "f1_battery_change": f1_battery_change,
        "f2_killed": f2_killed,
        "f3_fired": f3_fired,
        "f4_PLGR_now": f4_PLGR_now,
        "f5_unaware": f5_unaware,
        "f6_PLGR_before": f6_PLGR_before,
        "f7_second_calculation": f7_second_calculation,
        "f9_mistake_call": f9_mistake_call,
        "f10_landed": f10_landed,
        "f11_training": f11_training,
        "f12_PLGR_after": f12_PLGR_after,
        "f13_battery_died": f13_battery_died,
        "f14_wrong_position": f14_wrong_position,
    }

In [28]:
antecedents_fi1 = {"f6_PLGR_before": 0.0, "f7_second_calculation": 0.0}
consequents_fi = ["f2_killed"]
witnesses_fi = ["f4_PLGR_now", "f5_unaware", "f11_training", "f14_wrong_position"]
observations_fi = tensorize_observations({"u_f4_PLGR_now": 1.0, "u_f11_training": 1.0})

In [29]:
with MultiWorldCounterfactual() as mwc_fi1:
    with ExplainCauses(antecedents = antecedents_fi1, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_fi,
                        consequents = consequents_fi):
        with condition(data = observations_fi):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_fi1:
                    model_friendly_fire()              

In [30]:
ac_check(tr_fi1, mwc_fi1, antecedents_fi1, witnesses_fi, consequents_fi)

The antecedent set is not minimal.


False

In [31]:
# Dropping f7_second_calculation from the antecedent set leaves a smaller candidate cause.
antecedents_fi2 = {"f6_PLGR_before": 0.0}

with MultiWorldCounterfactual() as mwc_fi2:
    with ExplainCauses(antecedents = antecedents_fi2, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_fi,
                        consequents = consequents_fi):
        with condition(data = observations_fi):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_fi2:
                    model_friendly_fire()    
                    
ac_check(tr_fi2, mwc_fi2, antecedents_fi2, witnesses_fi, consequents_fi)

The antecedent set is an actual cause.


True
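
The two verdicts above are related: the two-element set is rejected as non-minimal because intervening on `f6_PLGR_before` alone already changes the outcome. A minimal sketch of that minimality test in plain Python follows. It assumes the observed context (`u_f4_PLGR_now = 1`, `u_f11_training = 1`) and uses the fact that, following the chain of deterministic sites in `model_friendly_fire`, `f2_killed` reduces to the conjunction of `f6_PLGR_before`, `f7_second_calculation` and `f11_training`. The helper names `killed`, `changes_outcome` and `is_minimal` are hypothetical, and witness sets are ignored for simplicity.

```python
from itertools import combinations


def killed(f6_PLGR_before=1, f7_second_calculation=1, f11_training=1):
    """Collapsed form of model_friendly_fire in the observed context.

    f2_killed is true exactly when all three inputs are true.
    The defaults are the actual values (every variable equals 1).
    """
    return int(f6_PLGR_before and f7_second_calculation and f11_training)


def changes_outcome(antecedents):
    """True if setting the antecedents to the given values flips f2_killed."""
    return killed(**antecedents) != killed()


def is_minimal(antecedents):
    """True if no proper, non-empty subset of the antecedents already flips the outcome."""
    names = list(antecedents)
    for size in range(1, len(names)):
        for subset in combinations(names, size):
            if changes_outcome({name: antecedents[name] for name in subset}):
                return False
    return True


both = {"f6_PLGR_before": 0, "f7_second_calculation": 0}
single = {"f6_PLGR_before": 0}

print(changes_outcome(both), is_minimal(both))
print(changes_outcome(single), is_minimal(single))
```

Both interventions prevent the outcome, but only the single-variable set passes the minimality test, which matches the two `ac_check` results.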

### Voting


In [32]:
def voting_model():
    u_vote0 = pyro.sample("u_vote0", dist.Bernoulli(0.6))
    u_vote1 = pyro.sample("u_vote1", dist.Bernoulli(0.6))
    u_vote2 = pyro.sample("u_vote2", dist.Bernoulli(0.6))
    u_vote3 = pyro.sample("u_vote3", dist.Bernoulli(0.6))
    u_vote4 = pyro.sample("u_vote4", dist.Bernoulli(0.6))
    u_vote5 = pyro.sample("u_vote5", dist.Bernoulli(0.6))

    vote0 = pyro.deterministic("vote0", u_vote0, event_dim=0)
    vote1 = pyro.deterministic("vote1", u_vote1, event_dim=0)
    vote2 = pyro.deterministic("vote2", u_vote2, event_dim=0)
    vote3 = pyro.deterministic("vote3", u_vote3, event_dim=0)
    vote4 = pyro.deterministic("vote4", u_vote4, event_dim=0)
    vote5 = pyro.deterministic("vote5", u_vote5, event_dim=0)
    
    outcome = pyro.deterministic("outcome", vote0 + vote1 + vote2 + vote3 + vote4 + vote5 > 3).float()
    return {"outcome": outcome}


In [33]:
antecedents_v = {"vote0": 0.0}
outcome_v = ["outcome"]
witnesses_v = [f"vote{i}" for i in range(1, 6)]
observations_v1 = tensorize_observations(dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                                              u_vote3=1., u_vote4=0., u_vote5=0.))

In [34]:
with MultiWorldCounterfactual() as mwc_v1:
    with ExplainCauses(antecedents = antecedents_v, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_v,
                        consequents = outcome_v):
        with condition(data = observations_v1):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_v1:
                    voting_model()

# With four of the six votes in favour (exactly the number needed to pass),
# each voter who voted in favour is an actual cause of the outcome.

ac_check(tr_v1, mwc_v1, antecedents_v, witnesses_v, outcome_v)

The antecedent set is an actual cause.


True

In [36]:
# With five of the six votes in favour, a single voter who voted in favour
# is not an actual cause: the outcome survives any one changed vote.

observations_v2 = tensorize_observations(dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=1., u_vote5=0.))

with MultiWorldCounterfactual() as mwc_v2:
    with ExplainCauses(antecedents = antecedents_v, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_v,
                        consequents = outcome_v):
        with condition(data = observations_v2):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_v2:
                    voting_model()
                    
ac_check(tr_v2, mwc_v2, antecedents_v, witnesses_v, outcome_v)

No resulting difference to the consequent in the sample.


In [37]:
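# With five votes in favour, two of those voters (here vote0 and vote1)
# are jointly an actual cause of the outcome.
antecedents_v3 = {"vote0": 0.0, "vote1": 0.0}
witnesses_v3 = [f"vote{i}" for i in range(2, 6)]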

with MultiWorldCounterfactual() as mwc_v3:
    with ExplainCauses(antecedents = antecedents_v3, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_v3,
                        consequents = outcome_v):
        with condition(data = observations_v2):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_v3:
                    voting_model()

ac_check(tr_v3, mwc_v3, antecedents_v3, witnesses_v3, outcome_v)

The antecedent set is an actual cause.


True
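
The pattern in the three voting checks can be reproduced by brute force. The sketch below, in plain Python and independent of the library, enumerates the smallest groups of voters in favour whose joint change of vote flips the result, using the same threshold as `voting_model` (more than three votes). The helper names `outcome` and `minimal_causes` are hypothetical, and the function is only meant for vote profiles in which the motion passes.

```python
from itertools import combinations


def outcome(votes):
    """The motion passes when more than three of the six votes are in favour."""
    return int(sum(votes) > 3)


def minimal_causes(votes):
    """List the smallest groups of 'yes' voters whose joint defection flips the outcome."""
    actual = outcome(votes)
    yes_voters = [i for i, v in enumerate(votes) if v == 1]
    for size in range(1, len(yes_voters) + 1):
        found = []
        for group in combinations(yes_voters, size):
            # Set the votes of the chosen group to 0 and keep all others as they were.
            flipped = [0 if i in group else v for i, v in enumerate(votes)]
            if outcome(flipped) != actual:
                found.append(group)
        if found:
            # Stop at the first size that works: larger groups would not be minimal.
            return found
    return []


four_yes = [1, 1, 1, 1, 0, 0]
five_yes = [1, 1, 1, 1, 1, 0]

print(minimal_causes(four_yes))
print(len(minimal_causes(five_yes)), (0, 1) in minimal_causes(five_yes))
```

With four votes in favour every such voter is a cause on their own; with five, no single voter is, and the minimal causes are the pairs of voters in favour, including the pair `vote0`, `vote1` checked above.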