# Explainable reasoning with ChiRho (categorical variables)


The **Explainable Reasoning with ChiRho** package aims to provide a systematic, unified approach to causal explanation computations in terms of different probabilistic queries over expanded causal models that are constructed from a single generic program transformation applied to an arbitrary causal model represented as a ChiRho program. The approach of reducing causal queries to probabilistic computations on transformed causal models is the foundational idea behind all of ChiRho. The key strategy underlying "causal explanation" queries is their use of auxiliary variables representing uncertainty over which interventions or preemptions to apply, implicitly inducing a search space over counterfactuals.

The goal of this notebook is to illustrate how the package can be used to provide an approximate method for answering a range of causal explanation queries with respect to models in which the key role is played by categorical variables. Continuous variables are the focus of another notebook.

In [another notebook](https://basisresearch.github.io/chirho/actual_causality.html) we illustrate how the module allows for a faithful reconstruction of a specific notion of local explanation that inspired some of the conceptual moves underlying the current implementation (the so-called Halpern-Pearl modified definition of actual causality [(J. Halpern, MIT Press, 2016)](https://mitpress.mit.edu/9780262537131/actual-causality/)).

**Outline**

[Introduction and motivations](#introduction-and-motivations)

- [The but-for condition](#the-but-for-condition)

- [Witness nodes and context-sensitivity](#witness-nodes-and-context-sensitivity)

[Simplified actual causality](#simplified-actual-causality)

[Probability of causation](#probability-of-causation)

[Causal explanation](#causal-explanation)

[Responsibility attribution](#responsibility-attribution)


In [1]:
import os

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch

import pandas as pd

from chirho.observational.handlers import condition
from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.explainable.handlers import SearchForExplanation
                                            
from chirho.indexed.ops import (IndexSet, gather, indices_of) 
from chirho.interventional.handlers import do


smoke_test = ('CI' in os.environ)
runs_n = 5 if smoke_test else 8000

## Introduction and motivations

Let's start with a model of a very simple situation, in which a forest fire can be caused by either of two factors: a match being dropped (`match_dropped`) or lightning (`lightning`). Either factor on its own is already deterministically sufficient for the `forest_fire` to occur. In general, you think a match being dropped is more likely than lightning (we use fairly large probabilities for the sake of example transparency).

In [2]:
def ff_disjunctive():
    match_dropped = pyro.sample("match_dropped", dist.Bernoulli(0.7)) # notice uneven probs here
    lightning = pyro.sample("lightning", dist.Bernoulli(0.4))

    forest_fire = pyro.deterministic("forest_fire", torch.max(match_dropped, lightning), event_dim=0)

    return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}

# each run is stochastic
ff_disjunctive()

{'match_dropped': tensor(0.),
 'lightning': tensor(1.),
 'forest_fire': tensor(1.)}
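
Because both parents are Bernoulli variables and the outcome is their maximum, the marginal probability of a fire can be checked by hand. The following is a minimal sketch in plain Python rather than Pyro; the helper name `enumerate_worlds` is hypothetical and introduced only for illustration.

```python
from itertools import product

# Probabilities taken from ff_disjunctive above.
P_MATCH = 0.7
P_LIGHTNING = 0.4


def enumerate_worlds():
    """Yield (match_dropped, lightning, forest_fire, probability) for all four worlds."""
    for match, lightning in product([0, 1], repeat=2):
        p_match = P_MATCH if match else 1 - P_MATCH
        p_lightning = P_LIGHTNING if lightning else 1 - P_LIGHTNING
        fire = max(match, lightning)  # same disjunctive mechanism as the model
        yield match, lightning, fire, p_match * p_lightning


# Marginal probability of a fire: sum over the worlds in which it occurs.
p_fire = sum(p for _, _, fire, p in enumerate_worlds() if fire == 1)

# Probability that a match was dropped, given that a fire occurred.
p_match_given_fire = (
    sum(p for match, _, fire, p in enumerate_worlds() if fire == 1 and match == 1)
    / p_fire
)

print(round(p_fire, 4))
print(round(p_match_given_fire, 4))
```

The fire fails to occur only when neither factor is present, so its probability is 1 - 0.3 * 0.6 = 0.82, which is what the enumeration reproduces.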

Suppose that in this particular case you know that a forest fire has occurred and a match has been dropped, but no lightning occurred. This further assumption can be introduced by conditioning on these observations:

In [3]:
observations = {"match_dropped": torch.tensor(1.), 
                "lightning": torch.tensor(0.),
                "forest_fire": torch.tensor(1.)}

with condition(data = observations):
    with pyro.poutine.trace() as tr:
        ff_disjunctive()

# now it is determined how things play out
print({key: tr.trace.nodes[key]["value"] for key in tr.trace.nodes.keys()})


{'match_dropped': tensor(1.), 'lightning': tensor(0.), 'forest_fire': tensor(1.)}


In a particular context like this, when you know what happened, or at least know the values of some of the variables at play, you might be interested in using the model to answer a range of causal-explanation related questions.

- Did the match being dropped actually cause the fire?
- Suppose you only know that the forest fire occurred. What are the likely explanations? How likely are they?
- Suppose you know both factors occurred. To what extent should each be deemed responsible for this outcome?

Let's see how these can be addressed using ChiRho. 

### The but-for condition

The initial intuition that you might have is that in this situation `match_dropped` is a cause of `forest_fire`, because had the match not been dropped, the forest fire would not have occurred, or, in other words, there would have been no forest fire but for the match being dropped. The `SearchForExplanation` handler can be used to test for this condition, and gives the expected result in this case.

In [47]:
antecedents = {"match_dropped": 0.0}
witnesses = {}
consequents = {"forest_fire": constraints.boolean}

with MultiWorldCounterfactual() as mwc:  # needed to keep track of multiple scenarios
    with SearchForExplanation(antecedents = antecedents, 
                              witnesses = witnesses, # no witnesses, ignore for now
                              consequents = consequents,
                              consequent_scale= 1e-8):
        with condition(data = observations):
            with pyro.plate("sample", 10): # run a few times
                with pyro.poutine.trace() as tr:
                    ff_disjunctive()

print(tr.trace.nodes.keys())


odict_keys(['__antecedent_match_dropped', 'match_dropped', 'lightning', 'forest_fire_factual', 'forest_fire_counterfactual', '__consequent_forest_fire', 'forest_fire'])


Intuitively, we used `SearchForExplanation` to investigate what would have happened to the consequent(s) if we intervened on the antecedent(s) as specified, and we run the model a few times as now the run contains stochastic elements. The trace now contains more information.

1. We now randomly intervened (for now, with a uniform distribution) on `match_dropped`, as specified in `antecedents`. The node `__antecedent_match_dropped` records whether the intervention was preempted: it has value `0` if the intervention wasn't blocked in a given run, and `1` if it was.


In [48]:
nd = (tr.trace.nodes)
print("__antecedent_match_dropped:", nd["__antecedent_match_dropped"]["value"])

__antecedent_match_dropped: tensor([0, 0, 0, 1, 1, 0, 1, 1, 0, 1])


At this point you might think that randomly preempting the intervention is an unnecessary complication, and in this particular case it is. The functionality is, however, useful in general when searching through possible antecedent sets. For this simple example we could suppress preemption entirely by shifting the uniform preemption probability down by .5. We'll just shift it down by .1 to be able to illustrate a point about log probabilities later on.

In [49]:
with MultiWorldCounterfactual() as mwc:  
    with SearchForExplanation(antecedents = antecedents, 
                              antecedent_bias= -.1, # we drop the probability of preemption
                              witnesses = witnesses, 
                              consequents = consequents,
                              consequent_scale= 1e-8):
        with condition(data = observations):
            with pyro.plate("sample", 10): 
                with pyro.poutine.trace() as tr:
                    ff_disjunctive()

tr.trace.compute_log_prob()
nd = (tr.trace.nodes)

print("__antecedent_match_dropped:", nd["__antecedent_match_dropped"]["value"])

__antecedent_match_dropped: tensor([1, 0, 0, 0, 0, 0, 1, 1, 0, 0])
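
The effect of `antecedent_bias` on how often an intervention is preempted can be mimicked without ChiRho. The sketch below assumes, as the prose above suggests, that preemption is a coin flip with probability `0.5 + bias`. Both function names are hypothetical, and this is an illustration of the idea rather than ChiRho's internal code.

```python
import random


def preemption_probability(bias):
    """Assumed rule: a fair coin shifted by `bias`, with bias in [-0.5, 0.5]."""
    if not -0.5 <= bias <= 0.5:
        raise ValueError("bias must lie in [-0.5, 0.5]")
    return 0.5 + bias


def sample_preemptions(bias, n, seed=0):
    """Draw n indicators: 1 means the intervention was preempted (blocked)."""
    rng = random.Random(seed)
    p = preemption_probability(bias)
    return [1 if rng.random() < p else 0 for _ in range(n)]


print(preemption_probability(0.0))          # unbiased: a fair coin
print(preemption_probability(-0.1))         # the setting used in the cell above
print(sum(sample_preemptions(-0.5, 1000)))  # bias -0.5: never preempted
```

Under this assumption a bias of `-.1` preempts roughly four interventions in ten, and a bias of `-.5` switches preemption off entirely.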


2. `match_dropped` and `forest_fire` now contain values for the factual and the counterfactual scenario:

In [38]:
with mwc: # use the same mwc context to keep track of what's what
    print( indices_of(nd["match_dropped"]["value"]))  # each potential upstream intervention extends the indices
    # use the indices to pick the right values
    antecedent_factual = gather(nd["match_dropped"]["value"], IndexSet(**{'match_dropped': {0}}))
    antecedent_counterfactual = gather(nd["match_dropped"]["value"], IndexSet(**{'match_dropped': {1}}))

    consequent_factual = gather(nd["forest_fire"]["value"], IndexSet(**{'match_dropped': {0}}))
    consequent_counterfactual = gather(nd["forest_fire"]["value"], IndexSet(**{'match_dropped': {1}}))


print("Antecedent Factual:\n", antecedent_factual)
print("Antecedent Counterfactual:\n", antecedent_counterfactual)
print("Consequent Factual:\n", consequent_factual)
print("Consequent Counterfactual:\n", consequent_counterfactual)

IndexSet({'match_dropped': {0, 1}})
Antecedent Factual:
 tensor([[[[[1., 1., 1., 1., 1.]]]]])
Antecedent Counterfactual:
 tensor([[[[[0., 0., 0., 0., 0.]]]]])
Consequent Factual:
 tensor([[[[[1., 1., 1., 1., 1.]]]]])
Consequent Counterfactual:
 tensor([[[[[0., 0., 0., 0., 0.]]]]])


3. Here we can already see that the answer is positive: the counterfactual value of the consequent differs from the factual one. To handle more general cases, `__consequent_forest_fire` records the score assigned to whether the factual and counterfactual values of the consequent differ. We used `consequent_scale= 1e-8`, which in the binary case results in `log_prob = 0` for cases in which there is a difference and in `-inf` for cases in which there isn't.

In [51]:
with mwc: 
    print(gather(nd['__consequent_forest_fire']['log_prob'], 
                IndexSet(**{'match_dropped': {1}})))

tensor([[[[[-inf, 0., 0., 0., 0., 0., -inf, -inf, 0., 0.]]]]])
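
The pattern in this output follows directly from the preemption indicators printed earlier: whenever the intervention is preempted, `match_dropped` keeps its factual value, so the consequent cannot differ. The sketch below reproduces that relationship with an idealized hard score (exactly `0` or `-inf`), which is the limiting behavior described above for a very small `consequent_scale`. The function `consequent_log_score` is a hypothetical stand-in, not ChiRho's scoring code.

```python
import math

# Preemption indicators printed above for the ten runs (1 = intervention blocked).
preempted = [1, 0, 0, 0, 0, 0, 1, 1, 0, 0]


def forest_fire(match_dropped, lightning):
    """Disjunctive mechanism from ff_disjunctive."""
    return max(match_dropped, lightning)


def consequent_log_score(is_preempted, lightning=0):
    """Idealized hard score: 0.0 if the consequent differs across worlds, -inf otherwise."""
    factual = forest_fire(1, lightning)
    # A preempted intervention leaves match_dropped at its factual value of 1.
    counterfactual_match = 1 if is_preempted else 0
    counterfactual = forest_fire(counterfactual_match, lightning)
    return 0.0 if factual != counterfactual else -math.inf


scores = [consequent_log_score(p) for p in preempted]
print(scores)
```

The resulting list has `-inf` in positions 0, 6 and 7 and `0.0` elsewhere, matching the tensor above.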


The problem with the but-for analysis of causality, though, is that it misdiagnoses causal factors in cases that involve overdetermination. If, for instance, we ask the same question in a context in which a match has been dropped and lightning has also struck, the answer will be negative, because, strictly speaking, preventing the match from being dropped wouldn't have prevented the forest fire. This is a misdiagnosis, as we would still like to think that the match being dropped played a causal role.


In [54]:
observations = {"match_dropped": torch.tensor(1.),
                "lightning": torch.tensor(1.),  # we changed this line 
                "forest_fire": torch.tensor(1.)}

with MultiWorldCounterfactual() as mwc:  
    with SearchForExplanation(antecedents = antecedents, 
                              witnesses = witnesses, 
                              consequents = consequents,
                              consequent_scale= 1e-8):
        with condition(data = observations):
            with pyro.plate("sample", 10):
                with pyro.poutine.trace() as tr:
                    ff_disjunctive()

tr.trace.compute_log_prob()
nd = tr.trace.nodes
with mwc: 
    print(gather(nd['__consequent_forest_fire']['log_prob'], 
                IndexSet(**{'match_dropped': {1}})))

tensor([[[[[-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]]]]])
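
The same contrast can be seen without any sampling by evaluating the but-for condition directly on the deterministic mechanism. This is a minimal sketch in plain Python; `is_but_for_cause` is a hypothetical helper, not part of ChiRho.

```python
def forest_fire(match_dropped, lightning):
    """Disjunctive mechanism from ff_disjunctive."""
    return max(match_dropped, lightning)


def is_but_for_cause(context, variable, alternative):
    """True if setting `variable` to `alternative` changes the outcome in `context`."""
    factual = forest_fire(**context)
    counterfactual = forest_fire(**{**context, variable: alternative})
    return factual != counterfactual


# Only the match was dropped: removing it prevents the fire.
print(is_but_for_cause({"match_dropped": 1, "lightning": 0}, "match_dropped", 0))

# Overdetermination: lightning alone still starts the fire.
print(is_but_for_cause({"match_dropped": 1, "lightning": 1}, "match_dropped", 0))
```

The first call returns `True` and the second `False`, which is the misdiagnosis discussed above.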


### Witness nodes and context-sensitivity

Some of these intuitions in the forest fire example can perhaps be salvaged by considering a two-membered antecedent set:

In [65]:
antecedents = {"match_dropped": 0.0, 'lightning': 0.0}

with MultiWorldCounterfactual() as mwc:  
    with SearchForExplanation(antecedents = antecedents, 
                              antecedent_bias= -.5, # enforce execution of the intervention
                              witnesses = witnesses, 
                              consequents = consequents,
                              consequent_scale= 1e-8):
        with condition(data = observations):
            with pyro.plate("sample", 10):
                with pyro.poutine.trace() as tr:
                    ff_disjunctive()

tr.trace.compute_log_prob()
nd = tr.trace.nodes
with mwc: 
    print(gather(nd['__consequent_forest_fire']['log_prob'], 
                IndexSet(**{'match_dropped': {1}, "lightning": {1}})))  
                # note we needed to add the index for lightning

tensor([[[[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]]]])


This already suggests a more complicated picture, as it turns out we need to pay attention to membership in larger antecedent sets that would make a difference (that is one reason why we need stochasticity in antecedent candidate preemption: to search for such subsets).
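
For a model this small, the search over antecedent subsets that stochastic preemption approximates can be carried out exhaustively. The sketch below enumerates every non-empty subset of candidate antecedents in the overdetermined context and keeps those whose joint intervention changes the outcome. The function `difference_making_sets` is hypothetical and only illustrates the idea.

```python
from itertools import combinations


def forest_fire(match_dropped, lightning):
    """Disjunctive mechanism from ff_disjunctive."""
    return max(match_dropped, lightning)


# Overdetermined context and the alternative value proposed for each candidate.
context = {"match_dropped": 1, "lightning": 1}
candidates = {"match_dropped": 0, "lightning": 0}


def difference_making_sets(context, candidates):
    """Return every subset of candidates whose joint intervention flips the outcome."""
    factual = forest_fire(**context)
    names = sorted(candidates)
    found = []
    for size in range(1, len(names) + 1):
        for subset in combinations(names, size):
            world = dict(context)
            world.update({name: candidates[name] for name in subset})
            if forest_fire(**world) != factual:
                found.append(subset)
    return found


print(difference_making_sets(context, candidates))
```

Only the two-membered set makes a difference here; neither singleton does. Exhaustive enumeration grows exponentially with the number of candidates, which is one motivation for sampling preemptions instead.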

But even then, the but-for analysis does not pay sufficient attention to the granularity of cause sets and to actual contexts. There are asymmetric cases in which the efficacy of one cause prevents the efficacy of another, and in which our causal attributions should also be asymmetric.

A simple example involves bottle shattering. Suppose Sally and Bill each throw a stone at a bottle, and Sally does so a bit earlier than Bill. Suppose both are perfectly accurate and the bottle shatters if hit. Sally hits and the bottle shatters, but Bill fails to hit, because the bottle isn't there anymore.

Sally's throw does not satisfy the but-for condition: if she hadn't thrown the stone, the bottle would still have shattered. Of course, the combined event of Sally throwing a stone and Bill throwing a stone is a but-for cause of the bottle shattering, so in some sense we can use the but-for clause to identify the whole set as a cause. But this doesn't capture the clear asymmetry involved here. Intuitively, Sally's throw is the (actual) cause of the bottle shattering in a sense in which Bill's throw isn't. Sally's throw actually caused the bottle to shatter and Bill's didn't, partially because Bill's stone actually failed to hit it.

An intuitive solution to the problem, inspired by the Halpern-Pearl definition of actual causality (which we discuss in [another notebook](https://basisresearch.github.io/chirho/actual_causality.html)), is to say that in answering actual causality queries we need to consider what happens when part of the actual context is kept fixed. For instance, in the bottle shattering example, given the observed fact that Bill's stone didn't hit, in the counterfactual world in which we keep this observed fact fixed, if Sally had not thrown the stone, the bottle would in fact not have shattered.
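
The role of keeping part of the context fixed can be checked by hand on a fully deterministic version of the story. The sketch below is plain Python, not ChiRho; the node names mirror those used in the `stones_model` cell of this notebook, and the helper `stones` with its `interventions` argument is hypothetical.

```python
def stones(interventions=None):
    """Deterministic stones mechanism; `interventions` overrides nodes by name."""
    fixed = interventions or {}
    values = {}

    def assign(name, computed):
        # An intervened node ignores its mechanism and takes the fixed value.
        values[name] = fixed.get(name, computed)

    assign("sally_throws", 1)
    assign("bill_throws", 1)
    assign("sally_hits", values["sally_throws"])
    # The second stone can only hit if the first one has not already hit.
    assign("bill_hits", int(values["bill_throws"] and not values["sally_hits"]))
    assign("bottle_shatters", int(values["sally_hits"] or values["bill_hits"]))
    return values


factual = stones()

# Plain but-for test on Sally's throw: the other stone hits instead.
plain = stones({"sally_throws": 0})

# Same intervention, but bill_hits is held at its factual value (a witness).
with_witness = stones({"sally_throws": 0, "bill_hits": factual["bill_hits"]})

# Symmetric test for the other throw, holding sally_hits at its factual value.
bill_case = stones({"bill_throws": 0, "sally_hits": factual["sally_hits"]})

print(factual["bottle_shatters"], plain["bottle_shatters"],
      with_witness["bottle_shatters"], bill_case["bottle_shatters"])
```

This prints `1 1 0 1`: the bottle stays intact only when Sally's throw is removed while the observed miss is held fixed, whereas no analogous choice of context makes the second throw a difference-maker. That is the asymmetry the witness mechanism is meant to capture.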

Thus, `SearchForExplanation` allows not only for stochastic preemption of interventions (to approximate a search through possible antecedent sets), but also for stochastic witness preemption of those nodes that are considered part of the context (the two don't need to exclude each other). In a witness preemption we ensure that the counterfactual value is identical to the factual one (and by applying it randomly to witness node candidates we approximate a search through all possible context sets). Let's define the model and apply the handler before we go through what the trace now contains.

In [69]:
def stones_model():        
    prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1))
    prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1))
    prob_sally_hits = pyro.sample("prob_sally_hits", dist.Beta(1, 1))
    prob_bill_hits = pyro.sample("prob_bill_hits", dist.Beta(1, 1))
    prob_bottle_shatters_if_sally = pyro.sample("prob_bottle_shatters_if_sally", dist.Beta(1, 1))
    prob_bottle_shatters_if_bill = pyro.sample("prob_bottle_shatters_if_bill", dist.Beta(1, 1))

    sally_throws = pyro.sample("sally_throws", dist.Bernoulli(prob_sally_throws))
    bill_throws = pyro.sample("bill_throws", dist.Bernoulli(prob_bill_throws))


    new_shp = torch.where(sally_throws == 1,prob_sally_hits, 0.0)

    sally_hits = pyro.sample("sally_hits",dist.Bernoulli(new_shp))

    new_bhp = torch.where(
        bill_throws.bool() & (~sally_hits.bool()),
        prob_bill_hits,
        torch.tensor(0.0),
    )

    bill_hits = pyro.sample("bill_hits", dist.Bernoulli(new_bhp))

    new_bsp = torch.where(bill_hits.bool(), prob_bottle_shatters_if_bill,
            torch.where(sally_hits.bool(),prob_bottle_shatters_if_sally,torch.tensor(0.0),),)

    bottle_shatters = pyro.sample("bottle_shatters", dist.Bernoulli(new_bsp))

    return {"sally_throws": sally_throws, "bill_throws": bill_throws,  "sally_hits": sally_hits,
            "bill_hits": bill_hits,  "bottle_shatters": bottle_shatters,}

stones_model.nodes = ["sally_throws","bill_throws", "sally_hits", "bill_hits","bottle_shatters",]

def tensorize_observations(observations):
    return {k: torch.as_tensor(v) for k, v in observations.items()}

# for now, we assume the mechanisms are deterministic
# and that both sally and bill throw stones
observations = {"prob_sally_throws": 1.0, 
                "prob_bill_throws": 1.0,
                "prob_sally_hits": 1.0,
                "prob_bill_hits": 1.0,
                "prob_bottle_shatters_if_sally": 1.0,
                "prob_bottle_shatters_if_bill": 1.0,
                "sally_throws": 1.0, "bill_throws": 1.0}

observations_tensorized = tensorize_observations(observations)

# instead of directly specifying an alternative scenario for the antecedent
# we can pass a constraint and at each run
# an intervened value is proposed automatically
# by sampling from an appropriate distribution

antecedents = {"sally_hits": constraints.boolean}
antecedent_bias = 0.1  # defined here but not passed to the handler call that follows
witnesses = {"bill_throws": constraints.boolean, "bill_hits": constraints.boolean}
consequents = {"bottle_shatters": constraints.boolean}

with MultiWorldCounterfactual() as mwc:
    with SearchForExplanation(antecedents = antecedents, 
                       witnesses = witnesses, consequents = consequents,
                       consequent_scale= 1e-8):
        with condition(data = observations_tensorized):
            with pyro.plate("sample", 100):
                with pyro.poutine.trace() as tr:
                    stones_model()

## Simplified actual causality

## Probability of causation

## Causal explanation

## Responsibility attribution
