# Responsibility and actual causality

**Preceding notebook**

- [Actual Causality: the modified Halpern-Pearl definition]() TODO add link


**Summary**

In a previous notebook, we introduced and implemented the Halpern-Pearl modified definition of actual causality. Here we implement the way Halpern used this notion to introduce his so-called *naive definition of responsibility*. We also briefly illustrate some reasons to think a somewhat more sophisticated notion is needed.

**Outline**

[Intuitions](##intuitions)
    
[Formalization](#formalization)

[Implementation](#implementation)

[Examples](#examples)

- [Comments on example selection](#comments-on-example-selection)
  
- [Voting](#voting)

- [Stone-throwing](#stone-throwing)

- [Firing squad](#firing-squad)


## Intuitions

The key idea here is that your responsibility for an outcome is to be measured in terms of how drastic a change would have to be made to the world for the outcome to depend counterfactually on your actions. However, the definition uses a fairly crude measure thereof, the minimal *number* of changes needed, where those numbers are individuated in terms of nodes. On one hand, if you are part of a cause, we count how many elements the cause has. On the other, we count the number of nodes that a witness set has. We add these two numbers for any combination of an actual cause and a witness set and we take the minimum, say $k$. Your responsibility is then $1/k$. 

## Formalization

The degree of responsibility of $X = x$ for $\varphi$ in $\langle M, \vec{u}\rangle$ is 0 if $X = x$ is not part of an actual cause of $\varphi$ in $\langle M, \vec{u}\rangle$ according
to the modified HP definition. It is $1/k$ if there exists an actual cause $\vec{X} = \vec{x}$ of $\varphi$ and a witness $\vec{W}$ to $\vec{X}=\vec{x}$ being a cause of $\varphi$ in $\langle M, \vec{u}\rangle$ such that 
(a) $X=x$ is a conjunct in $\vec{X}= \vec{x}$, (b) $\vert \vec{W}\vert + \vert\vec{X}\vert = k$, and (c) $k$ is minimal such a number.


## Implementation

In [1]:
import functools

import numpy as np
from itertools import combinations

import torch
from typing import Dict, List, Optional,  Union, Callable, Any

import pandas as pd

import pyro
import pyro.distributions as dist

from chirho.indexed.ops import IndexSet, gather, indices_of, scatter
from chirho.interventional.handlers import do
from chirho.counterfactual.ops import preempt
from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual, Preemptions
from chirho.observational.handlers import condition

In [2]:
class BiasedPreemptions(Preemptions):
    """
    Counterfactual handler that preempts the model with a biased coin flip.
    """
    def __init__(self, actions, weights: torch.Tensor,  event_dim: int = 0) -> None:
        self.weights = weights
        self.event_dim = event_dim
        super().__init__(actions)

    def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:
        try:
            name = msg["name"]
            action = self.actions[name]
        except KeyError:
            return
        value = msg["value"]
        factual_value = gather(value, IndexSet(**{name: {0}}),
                                event_dim=self.event_dim),
        counterfactual_value = gather(value, IndexSet(**{name: {1}}),
                                event_dim=self.event_dim),
        factual_value = preempt(
            factual_value,
            (action,),
            None,
            event_dim=len(msg["fn"].event_shape),
            name=msg["name"],
        )

        msg["value"] = scatter({
            IndexSet(**{name: {0}}): factual_value,
            IndexSet(**{name: {1}}): counterfactual_value,
        }, event_dim=self.event_dim)

    def _pyro_preempt(self,msg: Dict[str, Any]) -> None:
        if msg["name"] not in self.actions:
            return
        obs, acts, case = msg["args"]
        msg["kwargs"]["name"] = f"__split_{msg['name']}"
        case_dist = pyro.distributions.Categorical(self.weights)
        case = pyro.sample(msg["kwargs"]["name"], case_dist, obs=case)
        msg["args"] = (obs, acts, case)
        msg["stop"] = True

In [3]:
# slight modification wrt. to the orginal class:
# we use BiasedPreemption for witnesses
# with the same intervention bias as for antecedents
# to minimize for the number of active (antecedents + witnesses)

class HalpernPearlModifiedApproximate:

    def __init__(
        self, 
        model: Callable,
        counterfactual_antecedents: Dict[str, torch.Tensor],
        outcome: str,
        witness_candidates: List[str],
        observations: Optional[Dict[str, torch.Tensor]] = None
        ):
        
        if observations is None:
            observations = {}

        self.model = model
        self.counterfactual_antecedents = counterfactual_antecedents
        self.outcome = outcome
        self.witness_candidates = witness_candidates
        self.observations = observations

        self.antecedent_preemptions = {antecedent: functools.partial(self.preempt_with_factual,
                                                                     antecedents = [antecedent]) for
                                                                     antecedent in self.counterfactual_antecedents.keys()}
    
        self.witness_preemptions = {candidate: functools.partial(self.preempt_with_factual,
                                             antecedents = self.counterfactual_antecedents) for 
                                             candidate in self.witness_candidates}
        
    @staticmethod   
    def preempt_with_factual(value: torch.Tensor, *,
                          antecedents: List[str] = None, event_dim: int = 0):
    
        if antecedents is None:
            antecedents = []

        antecedents = [a for a in antecedents if a in indices_of(value, event_dim=event_dim)]

        factual_value = gather(value, IndexSet(**{antecedent: {0} for antecedent in antecedents}),
                                event_dim=event_dim)
            
        return scatter({
            IndexSet(**{antecedent: {0} for antecedent in antecedents}): factual_value,
            IndexSet(**{antecedent: {1} for antecedent in antecedents}): factual_value,
        }, event_dim=event_dim)
     
        
    def __call__(self, *args, **kwargs):
        with MultiWorldCounterfactual():
            with do(actions=self.counterfactual_antecedents):
                # the last element of the tensor is the factual case (preempted)
                with BiasedPreemptions(actions = self.antecedent_preemptions, weights = torch.tensor([.4, .6])):
                    #the last element is the fixed at the observed value (preempted)            
                    with BiasedPreemptions(actions = self.witness_preemptions, weights = torch.tensor([.6, .4])):
                        with condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):
                            with pyro.poutine.trace() as self.trace:
                                self.consequent = self.model(*args, **kwargs)[self.outcome]
                                self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.counterfactual_antecedents}))
                                self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.counterfactual_antecedents}))
                                self.consequent_differs = self.intervened_consequent != self.observed_consequent               
                                pyro.deterministic("consequent_differs_binary", self.consequent_differs, event_dim = 0) #feels inelegant
                                pyro.factor("consequent_differs", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))


In [4]:
# this will explore the trace once we run inference on the model

def get_table(nodes, antecedents, witness_candidates):
    
    values_table = {}

    for antecedent in antecedents:
        values_table[antecedent] = nodes[antecedent]["value"].squeeze().tolist()
        values_table['preempted_' + antecedent] = nodes['__split_' + antecedent]["value"].squeeze().tolist()
        values_table['preempted_' + antecedent + '_log_prob'] = nodes['__split_' + antecedent]["fn"].log_prob(nodes['__split_' + antecedent]["value"]).squeeze().tolist()


    for candidate in witness_candidates:
        _values = nodes[candidate]["value"].squeeze().tolist()
        # TODO: uncomment in the final version (?) 
        #values_table[candidate + '0'] = _values[0]
        #values_table[candidate + '1'] = _values[1]
        values_table['fixed_factual_' + candidate] = nodes['__split_' + candidate]["value"].squeeze().tolist()
 
    # TODO uncomment in the final version (?)
    #values_table[consequent + '0'] = nodes[consequent]["value"].squeeze().tolist()[0]
    #values_table[consequent + '1'] = nodes[consequent]["value"].squeeze().tolist()[1]
    values_table['consequent_differs_binary'] = nodes['consequent_differs_binary']["value"].squeeze().tolist()
    values_table['consequent_log_prob'] = nodes['consequent_differs']["fn"].log_prob(nodes['consequent_differs']["value"]).squeeze().tolist()

    if isinstance(values_table['consequent_log_prob'], float):
        values_df = pd.DataFrame([values_table])
    else:
        values_df = pd.DataFrame(values_table)
    

    summands = ['preempted_' + antecedent + '_log_prob' for antecedent in antecedents]
    summands.append('consequent_log_prob')
    values_df["sum_log_prob"] =  values_df[summands].sum(axis = 1) 
    values_df.drop_duplicates(inplace = True)
    values_df.sort_values(by = "sum_log_prob", inplace = True, ascending = False)

    return values_df.reset_index(drop = True)

In [5]:
def voting_model():
    u_vote0 = pyro.sample("u_vote0", dist.Bernoulli(0.6))
    u_vote1 = pyro.sample("u_vote1", dist.Bernoulli(0.6))
    u_vote2 = pyro.sample("u_vote2", dist.Bernoulli(0.6))
    u_vote3 = pyro.sample("u_vote3", dist.Bernoulli(0.6))
    u_vote4 = pyro.sample("u_vote4", dist.Bernoulli(0.6))
    u_vote5 = pyro.sample("u_vote5", dist.Bernoulli(0.6))
    u_vote6 = pyro.sample("u_vote6", dist.Bernoulli(0.6))
    u_vote7 = pyro.sample("u_vote7", dist.Bernoulli(0.6))

    vote0 = pyro.deterministic("vote0", u_vote0, event_dim=0)
    vote1 = pyro.deterministic("vote1", u_vote1, event_dim=0)
    vote2 = pyro.deterministic("vote2", u_vote2, event_dim=0)
    vote3 = pyro.deterministic("vote3", u_vote3, event_dim=0)
    vote4 = pyro.deterministic("vote4", u_vote4, event_dim=0)
    vote5 = pyro.deterministic("vote5", u_vote5, event_dim=0)
    vote6 = pyro.deterministic("vote6", u_vote6, event_dim=0)
    vote7 = pyro.deterministic("vote7", u_vote7, event_dim=0)


    outcome = pyro.deterministic("outcome", vote0 + vote1 + vote2 + vote3 + 
                                 vote4 + vote5 + vote6 + vote7 > 4)
    return {"outcome": outcome.float()}

In [6]:
# if you're one of five voters who voted for, you are an actual cause

# and your responsibility is 1 

observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=1., u_vote5=0.,
                        u_vote6=0., u_vote7=0.)



counterfactual_antecedents = {key[2:]: 1-v for key, v in observations.items()}


voting5HPM = HalpernPearlModifiedApproximate(
    model = voting_model,
    counterfactual_antecedents = counterfactual_antecedents,
    outcome = "outcome",
    witness_candidates = [f"vote{i}" for i in range(1,8)],
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=1., u_vote5=0.,
                        u_vote6=0., u_vote7=0.))


In [7]:
with pyro.plate("runs", 10):
    voting5HPM()


NotImplementedError: intervene not implemented for type <class 'tuple'>

This implementation is now used within another class definition, where, again, the main moves are in `def __call__`. We sample antecedent sets, leave other nodes (aside from the outcome) as witness candidates, and pass the result to an actual causality evaluation, keeping track of minimal antecedent sets and the corresponding witness sizes. Then we find a minimum of the sum and use it in the denominator.`

In [5]:
class HalpernPearlResponsibilityApproximate:

    def __init__(
        self, 
        model: Callable,
        nodes: List,
        antecedent: str,
        outcome: str,
        observations: Dict[str, torch.Tensor], 
        runs_n: int 
    ):
        self.model = model
        self.nodes = nodes
        self.antecedent = antecedent
        self.outcome = outcome
        self.observations = observations
        self.runs_n = runs_n
        
        self.minimal_antecedents_cache = []
        self.antecedent_sizes = []
        self.existential_but_fors = []
        self.acs = []
        self.minimal_witness_sizes = []
        self.responsibilities = []
        self.HPMs = []

    def __call__(self):
        
        for step in range(1,self.runs_n):

            nodes = self.nodes
            if self.outcome in nodes:
                nodes.remove(self.outcome) 
            
            companion_size = random.randint(0,len(nodes))
            companion_candidates = random.sample(self.nodes, companion_size)
            witness_candidates = [node for node in self.nodes if 
                                node != self.antecedent and 
                                node != self.outcome and 
                                    node not in companion_candidates]

            HPM = HalpernPearlModifiedApproximate(
                model = self.model,
                antecedents = companion_candidates,
                outcome = self.outcome,
                witness_candidates = witness_candidates,
                observations = self.observations,
                sample_size = 1000)
            
            HPM()
            self.HPMs.append(HPM)

            if HPM.existential_but_for:
                
                HPM_min = ac_minimality_check(HPM)

                if  HPM_min.ac:

                    subset_in_cache = any([s.issubset(set(HPM.antecedents)) for s in self.minimal_antecedents_cache])
                    if not subset_in_cache:
                        for s in self.minimal_antecedents_cache:
                            if set(HPM.antecedents).issubset(s):
                                self.minimal_antecedents_cache.remove(s)
                        self.minimal_antecedents_cache.append(set(HPM.antecedents))

                        if self.antecedent in HPM.antecedents:
                            self.antecedent_sizes.append(len(HPM.antecedents))
                            self.existential_but_fors.append(HPM.existential_but_for)
                            self.acs.append(HPM.ac)
                            self.minimal_witness_sizes.append(HPM.minimal_witness_size)
                            self.responsibilities.append(HPM.responsibility_internal)


        self.denumerators = [x + y for x, y in zip(self.antecedent_sizes, self.minimal_witness_sizes)]

        self.responsibilityDF = pd.DataFrame(
            {#"existential_but_for": [bool(value) for value in self.existential_but_fors],
             "acs": [bool(value) for value in self.acs],
                "antecedent_size": self.antecedent_sizes, 
                "minimal_witness_size": self.minimal_witness_sizes,
                "denumerator": self.denumerators,
                "responsibility": self.responsibilities
            }
            )
        if len(self.responsibilityDF['acs']) == 0:
            self.responsibility = 0
        else:
            min_denumerator = min(self.responsibilityDF['denumerator'])
            self.responsibility = 1/min_denumerator

 

## Examples

### Comments on example selection



- **Voting:** the example illustrates that parts of actual causes can share various degrees of responsibility for the outcome, without being actual causes.

- **Stone-throwing:** responsibility calculations in one of the main running examples in the *Actual Causality* book by Halpern (2016).

- **Firing squad:** an example in which responsibility and actual causality agree, where-as disussed in the notebook on the notion of blame-the notion of responsibility and blame will diverge.

### Voting

We discussed a similar model in a previous notebook. This time we have eight voters involved in a binary majority voting procedure and we investigate the responsibility assigned to voter 0. The situation is analogous to the one discussed in the actual causality notebook: if your vote is decisive, you are an actual cause, and you're not an actual cause otherwise. What's your responsibility, though? 

In [10]:
get_table(voting5HPM.trace.trace.nodes, antecedents = counterfactual_antecedents, witness_candidates = voting5HPM.witness_candidates)

KeyError: '__split_vote0'

In [50]:
# if everyone voted for, you are not an actual cause

everyone_voted_HPR = HalpernPearlResponsibilityApproximate(
    model = voting_model,
    nodes = [f"vote{i}" for i in range(0,8,)],
    antecedent = "vote0", outcome = "outcome",
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
    u_vote3=1., u_vote4=1., u_vote5=1., u_vote6 = 1., u_vote7 = 1.), 
    runs_n=500
    )

pyro.set_rng_seed(42)
everyone_voted_HPR()


In [52]:
# but the size-minimal actual causes are all of size 4
# so your responsibility is 1/4

everyone_voted_HPR.responsibilityDF


Unnamed: 0,acs,antecedent_size,minimal_witness_size,denumerator,responsibility
0,True,4,0,4,0.25
1,True,4,0,4,0.25
2,True,4,0,4,0.25
3,True,4,0,4,0.25
4,True,4,0,4,0.25
5,True,4,0,4,0.25
6,True,4,0,4,0.25
7,True,4,0,4,0.25
8,True,4,0,4,0.25
9,True,4,0,4,0.25


In [53]:
# four people would need to change their votes
# to change the outcome
# so your responsibility is 1/4

everyone_voted_HPR.responsibility

0.25

In [54]:
# if only seven people voted for, 
# your responsibility changes to 1/3

seven_voted_for_HPR = HalpernPearlResponsibilityApproximate(
    model = voting_model,
    nodes = [f"vote{i}" for i in range(0,8,)],
    antecedent = "vote0", outcome = "outcome",
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
    u_vote3=1., u_vote4=1., u_vote5=1., u_vote6 = 1., u_vote7 = 0.), 
    runs_n=500
    )

pyro.set_rng_seed(42)
seven_voted_for_HPR()

In [56]:
# your responsibility is 1/3 as in this case
# it would be enough for three people to vote against
# to change the outcome

seven_voted_for_HPR.responsibility

0.3333333333333333

### Stone-throwing


We've already discussed the model in the actual causality notebook. Sally and Bill throw stones at a bottle, Sally throws first. Bill is perfectly accurate, so his stone would have shattered the bottle had not Sally's stone done it. The model is worth looking at, as the causal structure is less trivial. Again, we will see that responsibility judgment might to some extent disagree with actual causality.

In [57]:
def stones_model():        
    prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1))
    prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1))
    prob_sally_hits = pyro.sample("prob_sally_hits", dist.Beta(1, 1))
    prob_bill_hits = pyro.sample("prob_bill_hits", dist.Beta(1, 1))
    prob_bottle_shatters_if_sally = pyro.sample("prob_bottle_shatters_if_sally", dist.Beta(1, 1))
    prob_bottle_shatters_if_bill = pyro.sample("prob_bottle_shatters_if_bill", dist.Beta(1, 1))


    sally_throws = pyro.sample("sally_throws", dist.Bernoulli(prob_sally_throws))
    bill_throws = pyro.sample("bill_throws", dist.Bernoulli(prob_bill_throws))

    new_shp = torch.where(sally_throws == 1,prob_sally_hits , 0.0)

    sally_hits = pyro.sample("sally_hits",dist.Bernoulli(new_shp))

    new_bhp = torch.where(
            (
                bill_throws.bool()
                & (~sally_hits.bool())
            )
            == 1,
            prob_bill_hits,
            torch.tensor(0.0),
        )


    bill_hits = pyro.sample("bill_hits", dist.Bernoulli(new_bhp))

    new_bsp = torch.where(
            bill_hits.bool() == 1,
            prob_bottle_shatters_if_bill,
            torch.where(
                sally_hits.bool() == 1,
                prob_bottle_shatters_if_sally,
                torch.tensor(0.0),
            ),
        )

    bottle_shatters = pyro.sample(
            "bottle_shatters", dist.Bernoulli(new_bsp)
        )

    return {
            "sally_throws": sally_throws,
            "bill_throws": bill_throws,
            "sally_hits": sally_hits,
            "bill_hits": bill_hits,
            "bottle_shatters": bottle_shatters,
        }

stones_model.nodes = [
            "sally_throws",
            "bill_throws",
            "sally_hits",
            "bill_hits",
            "bottle_shatters",
        ]

In [58]:
pyro.set_rng_seed(101)
responsibility_stones_sally_HPR = HalpernPearlResponsibilityApproximate(
    model = stones_model,
    nodes = stones_model.nodes,
    antecedent = "sally_throws", outcome = "bottle_shatters",
    observations = {"prob_sally_throws": 1, 
                    "prob_bill_throws": 1,
                    "prob_sally_hits": 1,
                    "prob_bill_hits": 1,
                    "prob_bottle_shatters_if_sally": 1,
                    "prob_bottle_shatters_if_bill": 1,
                    "sally_throws": 1, "bill_throws": 1},
                      runs_n=100)

responsibility_stones_sally_HPR()

In [59]:
# minimal witness size becomes non-trivial here
# we only record different minimal difference-making scenarios

responsibility_stones_sally_HPR.responsibilityDF

Unnamed: 0,acs,antecedent_size,minimal_witness_size,denumerator,responsibility
0,True,1,1,2,0.5


In [60]:
# following Halpern
# Sally's responsibility is 1/2

responsibility_stones_sally_HPR.responsibility

0.5

In [68]:

# Billy has degree of responsibility 0
# for the bottle shattering,
# as his throw is not a part of an actual cause

pyro.set_rng_seed(102)

responsibility_stones_bill_HPR = HalpernPearlResponsibilityApproximate(
    model = stones_model,
    nodes = stones_model.nodes,
    antecedent = "bill_throws", outcome = "bottle_shatters",
    observations = {"prob_sally_throws": 1, 
                    "prob_bill_throws": 1,
                    "prob_sally_hits": 1,
                    "prob_bill_hits": 1,
                    "prob_bottle_shatters_if_sally": 1,
                    "prob_bottle_shatters_if_bill": 1,
                    "sally_throws": 1, "bill_throws": 1},
                      runs_n=10)

In [66]:
pyro.set_rng_seed(101)
responsibility_stones_bill_HPR()

In [67]:
responsibility_stones_bill_HPR.responsibility

0

### Firing squad

There is a firing squad consisting of five excellent marksmen. Only one of them has a live bullet in his rifle and the rest have blanks. They shoot and the prisoner dies. The marksmen shoot at the prisoner and he dies. The only cause of the prisoner’s death is the marksman with the live bullet. That marksman has degree of responsibility 1 for the death and all the others have degree of responsibility 0. In the notebook on blame, TODO add link we will see that if the marksmen completely do not know which of them has the live bullet, blame is nevertheless equally distributed between them.

In [69]:

def firing_squad_model():
    probs = pyro.sample("probs", dist.Dirichlet(torch.ones(5)))

    who_has_bullet = pyro.sample("who_has_bullet", dist.OneHotCategorical(probs))

    mark0 = pyro.deterministic("mark0", torch.tensor([who[0] for who in who_has_bullet]), event_dim=0)
    mark1 = pyro.deterministic("mark1", torch.tensor([who[1] for who in who_has_bullet]), event_dim=0)
    mark2 = pyro.deterministic("mark2", torch.tensor([who[2] for who in who_has_bullet]), event_dim=0)
    mark3 = pyro.deterministic("mark3", torch.tensor([who[3] for who in who_has_bullet]), event_dim=0)
    mark4 = pyro.deterministic("mark4", torch.tensor([who[4] for who in who_has_bullet]), event_dim=0)

    dead = pyro.deterministic("dead", mark0 + mark1 + mark2 + mark3 + 
                                mark4  > 0)
    
    return {"probs": probs,
            "mark0": mark0,
            "mark1": mark1,
            "mark2": mark2,
            "mark3": mark3,
            "mark4": mark4, 
            "dead": dead}



In [70]:
pyro.set_rng_seed(102)

responsibility_loaded_HPR = HalpernPearlResponsibilityApproximate(
    model = firing_squad_model,
    nodes = ["mark" + str(i) for i in range(0,5)],
    antecedent = "mark0", outcome = "dead",
    observations = {"probs": torch.tensor([1., 0., 0., 0., 0.]),},
                      runs_n=50)


In [71]:
pyro.set_rng_seed(102)

responsibility_empty_HPR = HalpernPearlResponsibilityApproximate(
    model = firing_squad_model,
    nodes = ["mark" + str(i) for i in range(0,5)],
    antecedent = "mark1", outcome = "dead",
    observations = {"probs": torch.tensor([1., 0., 0., 0., 0.]),},
                      runs_n=50)

In [75]:
# If you have the live bullet

responsibility_loaded_HPR()
responsibility_loaded_HPR.responsibility

1.0

In [74]:
# if you have a blank,
# as we keep bullet's location constant in the model
# nothing can make a difference to mark1's contribution
# so his responsibility is zero

responsibility_empty_HPR()
responsibility_empty_HPR.responsibility

0