# Responsibility and actual causality

**Preceding notebook**

- [Actual Causality: the modified Halpern-Pearl definition]() TODO add link


**Summary**

In a previous notebook, we introduced and implemented the Halpern-Pearl modified definition of actual causality. Here we implement the way Halpern used this notion to introduce his so-called *naive definition of responsibility*. We also briefly illustrate some reasons to think a somewhat more sophisticated notion is needed.

**Outline**

[Intuitions](#intuitions)
    
[Formalization](#formalization)

[Implementation](#implementation)

[Examples](#examples)

- [Comments on example selection](#comments-on-example-selection)
  
- [Voting](#voting)

- [Stone-throwing](#stone-throwing)

- [Firing squad](#firing-squad)


## Intuitions

The key idea is that your responsibility for an outcome is measured by how drastic a change to the world would be needed for the outcome to depend counterfactually on your action. The definition uses a fairly crude measure of this: the minimal *number* of changes needed, where changes are individuated by nodes. On the one hand, if your action is part of an actual cause, we count how many elements that cause has. On the other, we count the number of nodes in a witness set. We add these two numbers for every combination of an actual cause and a witness set, and take the minimum, say $k$. Your responsibility is then $1/k$.

## Formalization

The degree of responsibility of $X = x$ for $\varphi$ in $\langle M, \vec{u}\rangle$ is 0 if $X = x$ is not part of an actual cause of $\varphi$ in $\langle M, \vec{u}\rangle$ according to the modified HP definition. It is $1/k$ if there exists an actual cause $\vec{X} = \vec{x}$ of $\varphi$ and a witness $\vec{W}$ to $\vec{X}=\vec{x}$ being a cause of $\varphi$ in $\langle M, \vec{u}\rangle$ such that (a) $X=x$ is a conjunct in $\vec{X}= \vec{x}$, (b) $\vert \vec{W}\vert + \vert\vec{X}\vert = k$, and (c) $k$ is the minimal such number.
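
To make the definition concrete, here is a minimal brute-force sketch in plain Python, independent of the Pyro/ChiRho implementation used in this notebook. It assumes binary nodes, structural equations listed in causal order, and a model small enough to enumerate. The helper names (`evaluate`, `naive_responsibility`, `majority_model`) are illustrative and not part of any library.

```python
from fractions import Fraction
from itertools import combinations, product


def evaluate(equations, context, overrides=None):
    """Evaluate structural equations in causal order, forcing overridden nodes."""
    overrides = overrides or {}
    values = dict(context)
    for name, fn in equations.items():
        values[name] = overrides[name] if name in overrides else fn(values)
    return values


def naive_responsibility(equations, context, target, outcome):
    """Degree of responsibility of `target` for `outcome`, by exhaustive search."""
    actual = evaluate(equations, context)
    nodes = [n for n in equations if n != outcome]

    def witness_sizes(cause):
        """Sizes of witness sets W for which some setting of `cause` flips the outcome."""
        others = [n for n in nodes if n not in cause]
        sizes = []
        for r in range(len(others) + 1):
            for witness in combinations(others, r):
                frozen = {n: actual[n] for n in witness}  # W is held at its actual values
                for setting in product([0, 1], repeat=len(cause)):
                    forced = {**dict(zip(cause, setting)), **frozen}
                    if evaluate(equations, context, forced)[outcome] != actual[outcome]:
                        sizes.append(r)
                        break
        return sizes

    satisfies_ac2 = {}
    best = None
    for size in range(1, len(nodes) + 1):  # smallest sets first, for the minimality check
        for cause in combinations(nodes, size):
            sizes = witness_sizes(cause)
            satisfies_ac2[cause] = bool(sizes)
            if not sizes or target not in cause:
                continue
            has_smaller_cause = any(
                satisfies_ac2[sub]
                for k in range(1, size)
                for sub in combinations(cause, k)
            )
            if not has_smaller_cause:  # the candidate set is a minimal (actual) cause
                total = size + min(sizes)
                best = total if best is None else min(best, total)
    return Fraction(1, best) if best else Fraction(0)


def majority_model(n_voters):
    """All-yes strict-majority vote: vote_i copies exogenous u_i (illustrative model)."""
    equations = {f"vote{i}": (lambda v, i=i: v[f"u{i}"]) for i in range(n_voters)}
    equations["outcome"] = lambda v: int(
        sum(v[f"vote{i}"] for i in range(n_voters)) > n_voters // 2
    )
    context = {f"u{i}": 1 for i in range(n_voters)}
    return equations, context


equations3, context3 = majority_model(3)
print(naive_responsibility(equations3, context3, "vote0", "outcome"))
```

For an all-yes vote, this enumeration gives 1/2 with three voters and 1/4 with seven. The search is exponential in the number of nodes, which is one motivation for a sampling-based approximation.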


## Implementation

In [1]:
import functools

import numpy as np
from itertools import combinations
import math

import torch
from typing import Dict, List, Optional,  Union, Callable, Any

import pandas as pd

import pyro
import pyro.distributions as dist

from chirho.indexed.ops import IndexSet, gather, indices_of, scatter
from chirho.interventional.handlers import do
from chirho.counterfactual.ops import preempt, intervene
from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual, Preemptions
from chirho.observational.handlers import condition


In [2]:

class BiasedPreemptions(Preemptions):
    """
    Counterfactual handler that preempts the model with a biased coin flip.
    """
    def __init__(self, actions, weights: torch.Tensor,  event_dim: int = 0, prefix: str = "__split_") -> None:
        self.weights = weights
        self.event_dim = event_dim
        self.prefix = prefix
        super().__init__(actions)


    def _pyro_preempt(self,msg: Dict[str, Any]) -> None:
        if msg["name"] not in self.actions:
            return   

        obs, acts, case = msg["args"]
        msg["kwargs"]["name"] = f"{self.prefix}{msg['name']}"
        case_dist = pyro.distributions.Categorical(self.weights)
        #print(msg["kwargs"]["name"] , self.prefix, msg['name'], self.weights)
        case = pyro.sample(msg["kwargs"]["name"], case_dist, obs=case)
        msg["args"] = (obs, acts, case)
        msg["stop"] = True

    def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:
        with pyro.poutine.messenger.block_messengers(
            lambda m : (isinstance(m, Preemptions) and (m is not self))
        ):
            super()._pyro_post_sample(msg) 

In [10]:
class HalpernPearlResponsibilityApproximate:

    def __init__(
        self, 
        model: Callable,
        evaluated_node_counterfactual: Dict[str, torch.Tensor],
        treatment_candidates: Dict[str, torch.Tensor],
        witness_candidates: List[str],
        outcome: str,
        observations: Optional[Dict[str, torch.Tensor]] = None,
        bias_t: float = .2
        ):
        
        if observations is None:
            observations = {}

        #if not set(witness_candidates) <= set(treatment_candidates.keys()):
        #    raise ValueError("witness_candidates must be a subset of treatment_candidates.keys().")
        
        self.model = model
        self.evaluated_node_counterfactual = evaluated_node_counterfactual
        self.treatment_candidates = treatment_candidates
        self.witness_candidates = witness_candidates
        self.outcome = outcome
        self.observations = observations
        # bias for treatment preemptions; the other two biases are derived from it
        self.bias_t = bias_t
        self.bias_n = self.find_max_bias_within(self.bias_t, len(self.treatment_candidates))
        self.bias_w = self.find_max_bias_within(self.bias_n, len(self.witness_candidates))

        self.evaluated_node_preemptions = {node: functools.partial(self.preempt_with_factual,
                                                                    antecedents = [node]) for
                                                                    node in self.evaluated_node_counterfactual.keys()}

        self.treatment_preemptions = {antecedent: functools.partial(self.preempt_with_factual,
                                                                     antecedents = [antecedent]) for
                                                                     antecedent in self.treatment_candidates.keys()}
    
        self.witness_preemptions = {candidate: functools.partial(self.preempt_with_factual,
                                             antecedents = self.treatment_candidates) for 
                                             candidate in self.witness_candidates}
        
    @staticmethod
    def find_max_bias_within(e: float, n: int,
    max_iterations: int = 1000, learning_rate: float = 0.002):
    
        ediff = math.log(0.5 + e) - math.log(0.5 - e)
        #print("up", math.log(0.5 + e), "down", math.log(0.5 - e), "ediff", ediff)

        w = e
        wdiff = math.log(0.5 + w) - math.log(0.5 - w)

        iteration = 0 
        while iteration < max_iterations and ediff <= n * wdiff:
       
            distance = n * wdiff / ediff
            assert w - learning_rate * distance >0 , "The learning rate is too high."
            w -= learning_rate * distance
        
            wdiff = math.log(0.5 + w) - math.log(0.5 - w)
            #print("up", math.log(0.5 + w), "down", math.log(0.5 - w), "wdiff", wdiff, "nwdiff", n * wdiff)

            iteration += 1
            
        return w

    @staticmethod   
    def preempt_with_factual(value: torch.Tensor, *,
                          antecedents: List[str] = None, event_dim: int = 0):
    
        if antecedents is None:
            antecedents = []

        antecedents = [a for a in antecedents if a in indices_of(value, event_dim=event_dim)]

        factual_value = gather(value, IndexSet(**{antecedent: {0} for antecedent in antecedents}),
                                event_dim=event_dim)
            
        return scatter({
            IndexSet(**{antecedent: {0} for antecedent in antecedents}): factual_value,
            IndexSet(**{antecedent: {1} for antecedent in antecedents}): factual_value,
        }, event_dim=event_dim)
     
        
    def __call__(self, *args, **kwargs):
        print("Preemption biases used (upper) - t:",.5+ self.bias_t, ", n:", .5 + self.bias_n, ", w:", .5 + self.bias_w, ".")
        with MultiWorldCounterfactual():
            with do(actions=self.evaluated_node_counterfactual):
                with BiasedPreemptions(actions = self.evaluated_node_preemptions, weights = torch.tensor([.5-self.bias_n, .5+self.bias_n]),
                                                    prefix = "__evaluated_split_"):
                    with do(actions=self.treatment_candidates):
                        with BiasedPreemptions(actions = self.treatment_preemptions, weights = torch.tensor([.5-self.bias_t, .5+self.bias_t]),
                                                        prefix = "__treatment_split_"):
                            # in each weights tensor, the last element is the probability of
                            # the preempted case, in which the node is held at its factual value
                            with BiasedPreemptions(actions = self.witness_preemptions, weights = torch.tensor([.5 + self.bias_w, .5-self.bias_w]),
                                           prefix = "__witness_split_"):

                                with condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):
                                    with pyro.poutine.trace() as self.trace:
                                        self.run = self.model(*args, **kwargs)
                                        self.consequent = self.run[self.outcome]
                                        self.interventionIndex = { intervention: {1} for intervention 
                                                                     in list(self.evaluated_node_counterfactual.keys()) + 
                                                                            list(self.treatment_candidates.keys()) +  self.witness_candidates}
                                    
                                        self.observedIndex = {node: {0} for node in list(self.evaluated_node_counterfactual.keys()) + 
                                                                            list(self.treatment_candidates.keys()) + self.witness_candidates}

                                        
                                        self.intervened_consequent = gather(self.consequent, IndexSet(**self.interventionIndex))
                                        
                                        self.observed_consequent = gather(self.consequent, IndexSet(**self.observedIndex))
                                        self.consequent_differs = self.intervened_consequent != self.observed_consequent               
                                        pyro.deterministic("consequent_differs_binary", self.consequent_differs, event_dim = 0) #feels inelegant
                                        pyro.factor("consequent_differs", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))
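
The nested biases set in `__init__` rely on a log-odds inequality: `find_max_bias_within(e, n)` lowers a bias `w` until `n` coin flips with bias `w` together cost less log-probability than a single flip with bias `e`. The sketch below restates that search as a standalone function so the property can be checked directly. The names `log_odds_gap` and `shrink_bias` are illustrative and do not appear in the class above.

```python
import math


def log_odds_gap(bias):
    """Log-probability gap between the two cases of a coin with weights 0.5 +/- bias."""
    return math.log(0.5 + bias) - math.log(0.5 - bias)


def shrink_bias(e, n, learning_rate=0.002, max_iterations=1000):
    """Lower w, starting from e, until n * gap(w) < gap(e) (restates find_max_bias_within)."""
    w = e
    for _ in range(max_iterations):
        if n * log_odds_gap(w) < log_odds_gap(e):
            break
        step = learning_rate * n * log_odds_gap(w) / log_odds_gap(e)
        assert w - step > 0, "The learning rate is too high."
        w -= step
    return w


bias_t = 0.2
for n in (2, 3, 6):
    w = shrink_bias(bias_t, n)
    # The second printed value reports whether the inequality holds for the returned bias.
    print(n, n * log_odds_gap(w) < log_odds_gap(bias_t))
```

Under this reading, `bias_n` is chosen relative to `bias_t` and the number of treatment candidates, and `bias_w` relative to `bias_n` and the number of witness candidates.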

In [11]:
# only needed for ease of exposition,
# not for the inference itself

def remove_redundant_rows(tab):
    existing_pairs = []

    for col in tab.columns:
        if col[0:4] == "apr_":
            ending = col.split("apr_")[1]
            wpr_col = f"wpr_{ending}"
            if wpr_col in tab.columns:
                existing_pairs.append((col,wpr_col))

    keep = []
    for index, row in tab.iterrows():
        
        flag = True
        for pair in existing_pairs:
            apr_col = pair[0]
            wpr_col = pair[1]
            apr_value = row[apr_col]
            wpr_value = row[wpr_col]
    
            if apr_value == 0 and wpr_value == 1:
                flag = False
                break
        keep.append(flag)
   
    return tab[keep]

In [12]:
# this will explore the trace once we run inference on the model

def get_table(nodes, evaluated_node, antecedents, witness_candidates, round = True):
    
    values_table = {}

    values_table[f"obs_{evaluated_node}"] = nodes[evaluated_node]["value"][0].squeeze().tolist()
    values_table[f"int_{evaluated_node}"] = nodes[evaluated_node]["value"][1].squeeze().tolist()
    values_table[f"epr_{evaluated_node}"] = nodes[f"__evaluated_split_{evaluated_node}"]["value"].squeeze().tolist()
    values_table[f"elp_{evaluated_node}"] = nodes[f"__evaluated_split_{evaluated_node}"]["fn"].log_prob(nodes[f"__evaluated_split_{evaluated_node}"]["value"]).squeeze().tolist()

    for antecedent in antecedents:
        values_table[f"obs_{antecedent}"] = nodes[antecedent]["value"][0].squeeze().tolist()
        values_table[f"int_{antecedent}"] = nodes[antecedent]["value"][1].squeeze().tolist()
        values_table['apr_' + antecedent] = nodes['__treatment_split_' + antecedent]["value"].squeeze().tolist()
        values_table['alp_' + antecedent] = nodes['__treatment_split_' + antecedent]["fn"].log_prob(nodes['__treatment_split_' + antecedent]["value"]).squeeze().tolist()



        if f"__witness_split_{antecedent}" in nodes.keys():
            values_table['wpr_' + antecedent] = nodes['__witness_split_' + antecedent]["value"].squeeze().tolist()
            values_table['wlp_' + antecedent] = nodes['__witness_split_' + antecedent]["fn"].log_prob(nodes['__witness_split_' + antecedent]["value"]).squeeze().tolist()

    
    values_table['cdif'] = nodes['consequent_differs_binary']["value"].squeeze().tolist()
    values_table['clp'] = nodes['consequent_differs']["fn"].log_prob(nodes['consequent_differs']["value"]).squeeze().tolist()

    # a single run yields scalars, which must be wrapped in a list to form one row
    if isinstance(values_table['clp'], float):
        values_df = pd.DataFrame([values_table])
    else:
        values_df = pd.DataFrame(values_table)

    summands_ant = ['alp_' + antecedent for antecedent in antecedents]
    summands_wit = ['wlp_' + witness for witness in witness_candidates]
    summands = [f"elp_{evaluated_node}"] +  summands_ant + summands_wit + ['clp']
    
    
    values_df["int"] =  values_df.apply(lambda row: sum(row[row.index.str.startswith("apr_")] == 0), axis=1)
    values_df['int'] = 1 - values_df[f"epr_{evaluated_node}"] + values_df["int"]
    values_df["wpr"] = values_df.apply(lambda row: sum(row[row.index.str.startswith("wpr_")] == 1), axis=1)
    values_df["changes"] =   values_df["int"] + values_df["wpr"]


    values_df["sum_lp"] =  values_df[summands].sum(axis = 1) 
    values_df.drop_duplicates(inplace = True)
    values_df.sort_values(by = "sum_lp", inplace = True, ascending = False)

    tab =  values_df.reset_index(drop = True)

    tab = remove_redundant_rows(tab)

    #tab = values_table

    if round:
        tab = tab.round(3)

    return tab


In [13]:
def responsibility_check(hpr):

    evaluated_node = list(hpr.evaluated_node_counterfactual.keys())[0]
    tab = get_table(hpr.trace.trace.nodes,
                    evaluated_node ,
                    list(hpr.treatment_candidates.keys()), 
                    hpr.witness_candidates)
    
    max_sum_lp = tab['sum_lp'].max()
    max_sum_lp_rows = tab[tab['sum_lp'] == max_sum_lp]

    # use positional access: the row labelled 0 may have been removed as redundant
    map_estimate = 1 / tab['changes'].iloc[0]

    print (f"MAP estimate: {map_estimate}")

    # sanity check; consider removing later
    min_changes = max_sum_lp_rows['changes'].min()
    min_changes_row = max_sum_lp_rows[max_sum_lp_rows['changes'] == min_changes]

    print("Minimal scenarios:")
    print(min_changes_row)

    if not (min_changes_row[f'int_{evaluated_node}'] == 0).any():
        print (f"No MAP estimate includes intervention on int_{evaluated_node} == 0")
        return 0
    
    min_changes_row = min_changes_row[min_changes_row[f'int_{evaluated_node}'] == 0]

    secondary_check = 1/min_changes_row['changes'].min()

    print (f"Secondary check: {secondary_check}")

    assert map_estimate == secondary_check, "MAP estimate does not match secondary check, increase sample size."  

    return map_estimate

#TODO THIS NEEDS TO BE UPDATED 
This implementation is now used within another class definition, where, again, the main moves are in `def __call__`. We sample antecedent sets, leave other nodes (aside from the outcome) as witness candidates, and pass the result to an actual causality evaluation, keeping track of minimal antecedent sets and the corresponding witness sizes. Then we find the minimum of the sum and use it in the denominator.
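
The sampling idea can be sketched in plain Python, without Pyro or ChiRho. This is a simplification under stated assumptions, not the notebook's implementation: it uses fair coins instead of biased preemptions, reads off the minimum number of changes directly instead of a MAP estimate, and does not check minimality of the cause, which does not matter for this flat three-voter model. The names `outcome` and `sampled_responsibility` are illustrative.

```python
import random
from fractions import Fraction


def outcome(votes):
    """Strict majority over three binary votes, as in the small voting example."""
    return int(sum(votes.values()) > 1)


def sampled_responsibility(actual, target, n_samples=2000, seed=0):
    """Sample intervention and witness sets; keep the fewest changes that flip the outcome."""
    rng = random.Random(seed)
    others = [n for n in actual if n != target]
    best = None
    for _ in range(n_samples):
        # The target is always intervened on; the other nodes are sampled.
        flipped = {target} | {n for n in others if rng.random() < 0.5}
        # Nodes that are not intervened on may be held fixed as witnesses.
        frozen = {n for n in others if n not in flipped and rng.random() < 0.5}
        world = {n: 1 - v if n in flipped else v for n, v in actual.items()}
        if outcome(world) != outcome(actual):
            changes = len(flipped) + len(frozen)
            best = changes if best is None else min(best, changes)
    return Fraction(1, best) if best else Fraction(0)


all_yes = {"vote0": 1, "vote1": 1, "vote2": 1}
print(sampled_responsibility(all_yes, "vote0"))
```

With fair coins every scenario is equally likely, so the minimum has to be tracked explicitly. The biased preemptions in the class above instead make scenarios with fewer changes more probable, so that they appear at the top of the log-probability ranking.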

## Examples

### Comments on example selection



- **Voting:** the example illustrates that parts of actual causes can share various degrees of responsibility for the outcome, without being actual causes.

- **Stone-throwing:** responsibility calculations in one of the main running examples in the *Actual Causality* book by Halpern (2016).

- **Firing squad:** an example in which responsibility and actual causality agree, but in which, as discussed in the notebook on blame, the notions of responsibility and blame diverge.

### Voting

We discussed a similar model in a previous notebook. This time we consider a binary majority vote, first with three voters and then with seven, and investigate the responsibility assigned to voter 0. As in the actual causality notebook, if your vote is decisive you are an actual cause, and otherwise you are not. What is your responsibility, though?

In [14]:
# let's start with a minimal interesting example
# you are one of three voters

def voting_model():
    u_vote0 = pyro.sample("u_vote0", dist.Bernoulli(0.6))
    u_vote1 = pyro.sample("u_vote1", dist.Bernoulli(0.6))
    u_vote2 = pyro.sample("u_vote2", dist.Bernoulli(0.6))

    vote0 = pyro.deterministic("vote0", u_vote0, event_dim=0)
    vote1 = pyro.deterministic("vote1", u_vote1, event_dim=0)
    vote2 = pyro.deterministic("vote2", u_vote2, event_dim=0)

    outcome = pyro.deterministic("outcome", vote0 + vote1 + vote2 >1
                                 )
    return {"outcome": outcome.float()}

observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.
                        )

treatment_candidates = {key[2:]: 1-v for key, v in observations.items() if key  != "u_vote0"}
evaluated_node_counterfactual = {"vote0": 1 - observations["u_vote0"]}

votingHPR = HalpernPearlResponsibilityApproximate(
    model = voting_model,
    evaluated_node_counterfactual = evaluated_node_counterfactual,
    treatment_candidates = treatment_candidates,
    outcome = "outcome",
    witness_candidates = [f"vote{i}" for i in range(1,3)],
    observations = observations)


In [15]:
# Now inference, and let's inspect the table while it's small

with pyro.plate("runs", 1000):
    votingHPR()

vtr = votingHPR.trace.trace.nodes

get_table(vtr, "vote0", treatment_candidates, [f"vote{i}" for i in range(1,3)])

Preemption biases used (upper) - t: 0.7 , n: 0.6024412643276109 , w: 0.5502509213795265 .


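|  | obs_vote0 | int_vote0 | epr_vote0 | elp_vote0 | obs_vote1 | int_vote1 | apr_vote1 | alp_vote1 | wpr_vote1 | wlp_vote1 | ... | apr_vote2 | alp_vote2 | wpr_vote2 | wlp_vote2 | cdif | clp | int | wpr | changes | sum_lp |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 1.0 | 0.0 | 0 | -0.922 | 1.0 | 1.0 | 1 | -0.357 | 0 | -0.597 | ... | 0 | -1.204 | 0 | -0.597 | True | 0.0 | 2 | 0 | 2 | -3.678 |
| 1 | 1.0 | 0.0 | 0 | -0.922 | 1.0 | 0.0 | 0 | -1.204 | 0 | -0.597 | ... | 1 | -0.357 | 0 | -0.597 | True | 0.0 | 2 | 0 | 2 | -3.678 |
| 2 | 1.0 | 0.0 | 0 | -0.922 | 1.0 | 0.0 | 0 | -1.204 | 0 | -0.597 | ... | 1 | -0.357 | 1 | -0.799 | True | 0.0 | 2 | 1 | 3 | -3.88 |
| 3 | 1.0 | 0.0 | 0 | -0.922 | 1.0 | 1.0 | 1 | -0.357 | 1 | -0.799 | ... | 0 | -1.204 | 0 | -0.597 | True | 0.0 | 2 | 1 | 3 | -3.88 |
| 4 | 1.0 | 1.0 | 1 | -0.507 | 1.0 | 0.0 | 0 | -1.204 | 0 | -0.597 | ... | 0 | -1.204 | 0 | -0.597 | True | 0.0 | 2 | 0 | 2 | -4.109 |
| 5 | 1.0 | 0.0 | 0 | -0.922 | 1.0 | 0.0 | 0 | -1.204 | 0 | -0.597 | ... | 0 | -1.204 | 0 | -0.597 | True | 0.0 | 3 | 0 | 3 | -4.525 |
| 8 | 1.0 | 1.0 | 1 | -0.507 | 1.0 | 1.0 | 1 | -0.357 | 0 | -0.597 | ... | 1 | -0.357 | 0 | -0.597 | False | -100000000.0 | 0 | 0 | 0 | -100000000.0 |
| 9 | 1.0 | 1.0 | 1 | -0.507 | 1.0 | 1.0 | 1 | -0.357 | 0 | -0.597 | ... | 1 | -0.357 | 1 | -0.799 | False | -100000000.0 | 0 | 1 | 1 | -100000000.0 |
| 10 | 1.0 | 1.0 | 1 | -0.507 | 1.0 | 1.0 | 1 | -0.357 | 1 | -0.799 | ... | 1 | -0.357 | 0 | -0.597 | False | -100000000.0 | 0 | 1 | 1 | -100000000.0 |
| 11 | 1.0 | 1.0 | 1 | -0.507 | 1.0 | 1.0 | 1 | -0.357 | 1 | -0.799 | ... | 1 | -0.357 | 1 | -0.799 | False | -100000000.0 | 0 | 2 | 2 | -100000000.0 |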


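Each row of the table is one distinct sampled scenario, and rows are sorted by total log-probability. The column names follow the code in `get_table`:

- `obs_` and `int_`: the value of a node in the factual world and in the intervened world, respectively.

- `epr_`, `apr_`, `wpr_`: the sampled preemption case for the evaluated node, a treatment candidate, and a witness candidate. For `epr_` and `apr_`, `0` means the intervention is applied and `1` means the node is left at its factual value. For `wpr_`, `1` means the witness is held fixed at its factual value.

- `elp_`, `alp_`, `wlp_`: the log-probabilities of those sampled cases.

- `cdif` and `clp`: whether the outcome differs between the two worlds, and the corresponding factor (`0` if it does, `-1e8` otherwise).

- `int`, `wpr`, `changes`: the number of interventions, the number of fixed witnesses, and their sum.

- `sum_lp`: the sum of the log-probability columns.

Rows in which a node is both intervened on and held fixed as a witness are removed as redundant.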
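
The log-probability columns in the table above can be recovered from the preemption biases printed before it. The short check below does this for row 0, assuming the weight ordering used in `__call__` (which case is favoured for each kind of split). The variable names are illustrative.

```python
import math

# Upper preemption probabilities printed by the run above (t, n, w).
upper_t, upper_n, upper_w = 0.7, 0.6024412643276109, 0.5502509213795265

elp_intervened = math.log(1 - upper_n)  # evaluated node intervened on (epr == 0)
alp_preempted = math.log(upper_t)       # treatment candidate left factual (apr == 1)
alp_intervened = math.log(1 - upper_t)  # treatment candidate intervened on (apr == 0)
wlp_free = math.log(upper_w)            # witness candidate not held fixed (wpr == 0)
wlp_fixed = math.log(1 - upper_w)       # witness candidate held fixed (wpr == 1)

# Row 0: vote0 and vote2 are intervened on, vote1 is preempted, no witness is fixed.
# The outcome differs, so the factor term clp contributes 0.
row0_sum_lp = elp_intervened + alp_preempted + wlp_free + alp_intervened + wlp_free
row0_changes = 2 + 0  # two interventions plus zero fixed witnesses

print(round(row0_sum_lp, 3), row0_changes, 1 / row0_changes)
```

Rounded to three decimals, these values match the `elp_`, `alp_`, `wlp_` and `sum_lp` entries shown in the table.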

In [16]:
responsibility_check(votingHPR)

MAP estimate: 0.5
Minimal scenarios:
   obs_vote0  int_vote0  epr_vote0  elp_vote0  obs_vote1  int_vote1  \
0        1.0        0.0          0     -0.922        1.0        1.0   
1        1.0        0.0          0     -0.922        1.0        0.0   

   apr_vote1  alp_vote1  wpr_vote1  wlp_vote1  ...  apr_vote2  alp_vote2  \
0          1     -0.357          0     -0.597  ...          0     -1.204   
1          0     -1.204          0     -0.597  ...          1     -0.357   

   wpr_vote2  wlp_vote2  cdif  clp  int  wpr  changes  sum_lp  
0          0     -0.597  True  0.0    2    0        2  -3.678  
1          0     -0.597  True  0.0    2    0        2  -3.678  

[2 rows x 22 columns]
Secondary check: 0.5


0.5

In [17]:
# now consider a more complex example,
# with 7 voters, where you are not an actual cause

def voting_model7():
    u_vote0 = pyro.sample("u_vote0", dist.Bernoulli(0.6))
    u_vote1 = pyro.sample("u_vote1", dist.Bernoulli(0.6))
    u_vote2 = pyro.sample("u_vote2", dist.Bernoulli(0.6))
    u_vote3 = pyro.sample("u_vote3", dist.Bernoulli(0.6))
    u_vote4 = pyro.sample("u_vote4", dist.Bernoulli(0.6))
    u_vote5 = pyro.sample("u_vote5", dist.Bernoulli(0.6))
    u_vote6 = pyro.sample("u_vote6", dist.Bernoulli(0.6))
  

    vote0 = pyro.deterministic("vote0", u_vote0, event_dim=0)
    vote1 = pyro.deterministic("vote1", u_vote1, event_dim=0)
    vote2 = pyro.deterministic("vote2", u_vote2, event_dim=0)
    vote3 = pyro.deterministic("vote3", u_vote3, event_dim=0)
    vote4 = pyro.deterministic("vote4", u_vote4, event_dim=0)
    vote5 = pyro.deterministic("vote5", u_vote5, event_dim=0)
    vote6 = pyro.deterministic("vote6", u_vote6, event_dim=0)

    outcome = pyro.deterministic("outcome", vote0 + vote1 + vote2 + vote3 + 
                                 vote4  + vote5 + vote6 > 3 )
    return {"outcome": outcome.float()}


In [18]:
# everyone voted for,
# you are not an actual cause 
# the minimal number of interventions 
# including your change of vote
# needed to change the outcome is 4
# so your responsibility is 1/4

observations7 = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=1.,
                        u_vote5=1.,
                        u_vote6=1.,
                        )


treatment_candidates7 = {key[2:]: 1-v for key, v in observations7.items() if key  != "u_vote0"}

evaluated_node_counterfactual7 = {"vote0": 1 - observations7["u_vote0"]}

voting8HPR = HalpernPearlResponsibilityApproximate(
    model = voting_model7,
    evaluated_node_counterfactual = evaluated_node_counterfactual7,
    treatment_candidates = treatment_candidates7,
    outcome = "outcome",
    witness_candidates = [f"vote{i}" for i in range(1,7)],
    observations = observations7)

with pyro.plate("runs", 10000):
    voting8HPR()


Preemption biases used (upper) - t: 0.7 , n: 0.5337756999955469 , w: 0.5037726755473835 .


In [19]:
responsibility_check(voting8HPR)

MAP estimate: 0.25
Minimal scenarios:
   obs_vote0  int_vote0  epr_vote0  elp_vote0  obs_vote1  int_vote1  \
0        1.0        0.0          0     -0.763        1.0        0.0   
1        1.0        0.0          0     -0.763        1.0        1.0   
2        1.0        0.0          0     -0.763        1.0        1.0   
3        1.0        0.0          0     -0.763        1.0        0.0   
4        1.0        0.0          0     -0.763        1.0        1.0   
5        1.0        0.0          0     -0.763        1.0        0.0   
6        1.0        0.0          0     -0.763        1.0        0.0   

   apr_vote1  alp_vote1  wpr_vote1  wlp_vote1  ...  apr_vote6  alp_vote6  \
0          0     -1.204          0     -0.686  ...          1     -0.357   
1          1     -0.357          0     -0.686  ...          0     -1.204   
2          1     -0.357          0     -0.686  ...          1     -0.357   
3          0     -1.204          0     -0.686  ...          1     -0.357   
4          1 

0.25
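
For a flat majority vote, the count described in the code comments above (everyone voted for, four changes needed, responsibility 1/4) has a simple closed form. The sketch below assumes independent binary ballots feeding a strict-majority rule, as in `voting_model7`. The function name `majority_vote_responsibility` is illustrative.

```python
from fractions import Fraction


def majority_vote_responsibility(votes, voter):
    """Naive responsibility of `voter` in a strict-majority vote over binary ballots."""
    n_yes = sum(votes)
    threshold = len(votes) // 2  # the motion passes when n_yes > threshold
    passed = n_yes > threshold
    if votes[voter] != int(passed):
        return Fraction(0)  # voted against the outcome: not part of any actual cause
    # Number of winning-side ballots, including this voter's, that must flip.
    k = n_yes - threshold if passed else threshold + 1 - n_yes
    return Fraction(1, k)


print(majority_vote_responsibility([1] * 7, voter=0))
print(majority_vote_responsibility([1] * 3, voter=0))
```

No witness set is needed here because no voter's ballot depends on another's, so $k$ is just the size of the smallest actual cause containing the voter.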

### Stone-throwing


We have already discussed this model in the actual causality notebook. Sally and Bill throw stones at a bottle, and Sally throws first. Bill is perfectly accurate, so his stone would have shattered the bottle had Sally's not done so. The model is worth looking at because its causal structure is less trivial. Again, we will see that responsibility judgments may to some extent disagree with actual causality.
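
The preemption structure can be seen in a deterministic plain-Python version of the story, which assumes every throw hits and every hit shatters the bottle. The function `stones` and its `frozen` argument are illustrative and separate from the Pyro model defined in the next cell.

```python
def stones(sally_throws=1, bill_throws=1, frozen=None):
    """Deterministic stone-throwing model; `frozen` pins chosen nodes to given values."""
    frozen = frozen or {}
    v = {"sally_throws": sally_throws, "bill_throws": bill_throws}
    v["sally_hits"] = frozen.get("sally_hits", v["sally_throws"])
    # Bill's stone hits only if he throws and Sally's stone has not already hit.
    v["bill_hits"] = frozen.get(
        "bill_hits", int(v["bill_throws"] and not v["sally_hits"])
    )
    v["bottle_shatters"] = int(v["sally_hits"] or v["bill_hits"])
    return v


actual = stones()
but_for = stones(sally_throws=0)
with_witness = stones(sally_throws=0, frozen={"bill_hits": actual["bill_hits"]})

print(actual["bottle_shatters"], but_for["bottle_shatters"], with_witness["bottle_shatters"])
```

If Sally does not throw, Bill's stone shatters the bottle anyway, so the outcome does not depend on her throw alone. Once `bill_hits` is held at its actual value of 0, it does. In this sketch that is one intervened node plus one witness node, so $k = 2$.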

In [20]:
def stones_model():        
    prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1))
    prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1))
    prob_sally_hits = pyro.sample("prob_sally_hits", dist.Beta(1, 1))
    prob_bill_hits = pyro.sample("prob_bill_hits", dist.Beta(1, 1))
    prob_bottle_shatters_if_sally = pyro.sample("prob_bottle_shatters_if_sally", dist.Beta(1, 1))
    prob_bottle_shatters_if_bill = pyro.sample("prob_bottle_shatters_if_bill", dist.Beta(1, 1))


    sally_throws = pyro.sample("sally_throws", dist.Bernoulli(prob_sally_throws))
    bill_throws = pyro.sample("bill_throws", dist.Bernoulli(prob_bill_throws))

    new_shp = torch.where(sally_throws == 1, prob_sally_hits, torch.tensor(0.0))

    sally_hits = pyro.sample("sally_hits",dist.Bernoulli(new_shp))

    new_bhp = torch.where(
            (
                bill_throws.bool()
                & (~sally_hits.bool())
            )
            == 1,
            prob_bill_hits,
            torch.tensor(0.0),
        )


    bill_hits = pyro.sample("bill_hits", dist.Bernoulli(new_bhp))

    new_bsp = torch.where(
            bill_hits.bool() == 1,
            prob_bottle_shatters_if_bill,
            torch.where(
                sally_hits.bool() == 1,
                prob_bottle_shatters_if_sally,
                torch.tensor(0.0),
            ),
        )

    bottle_shatters = pyro.sample(
            "bottle_shatters", dist.Bernoulli(new_bsp)
        )

    return {
            "sally_throws": sally_throws,
            "bill_throws": bill_throws,
            "sally_hits": sally_hits,
            "bill_hits": bill_hits,
            "bottle_shatters": bottle_shatters,
        }



In [21]:

pyro.set_rng_seed(4)
stones_sallyHPR = HalpernPearlResponsibilityApproximate(
    model = stones_model,
    evaluated_node_counterfactual= {"sally_throws": 0.0},
    treatment_candidates = {"sally_hits": 0.0, "bill_hits": 1.0, "bill_throws": 0.0},
    outcome = "bottle_shatters",
    witness_candidates = ["bill_hits", "bill_throws", "sally_hits"],
    observations = {"prob_sally_throws": 1.0, 
                    "prob_bill_throws": 1.0,
                    "prob_sally_hits": 1.0,
                    "prob_bill_hits": 1.0,
                    "prob_bottle_shatters_if_sally": 1.0,
                    "prob_bottle_shatters_if_bill": 1.0})

with pyro.plate("runs",10000):
    stones_sallyHPR()

Preemption biases used (upper) - t: 0.7 , n: 0.5693194340229132 , w: 0.5214645870036448 .


In [22]:
def gett(nodes, evaluated_node, antecedents, witness_candidates, round = True):
    
    values_table = {}


#    values_table[f"obs_{evaluated_node}"] = nodes[evaluated_node]["value"][0].squeeze().tolist()
#    values_table[f"int_{evaluated_node}"] = nodes[evaluated_node]["value"][1].squeeze().tolist()
    values_table[f"epr_{evaluated_node}"] = nodes[f"__evaluated_split_{evaluated_node}"]["value"].squeeze().tolist()
    values_table[f"elp_{evaluated_node}"] = nodes[f"__evaluated_split_{evaluated_node}"]["fn"].log_prob(nodes[f"__evaluated_split_{evaluated_node}"]["value"]).squeeze().tolist()

    for antecedent in antecedents:
#        andecedent_m = HPR.run[antecedent]
#        print(gather(andecedent_m, IndexSet(**{antecedent: {0} for antecedent in antecedents})))
#        values_table[f"obs_{antecedent}"] = nodes[antecedent]["value"][0].squeeze().tolist()
#        values_table[f"int_{antecedent}"] = nodes[antecedent]["value"][1].squeeze().tolist()
        values_table['apr_' + antecedent] = nodes['__treatment_split_' + antecedent]["value"].squeeze().tolist()
        values_table['alp_' + antecedent] = nodes['__treatment_split_' + antecedent]["fn"].log_prob(nodes['__treatment_split_' + antecedent]["value"]).squeeze().tolist()

        if f"__witness_split_{antecedent}" in nodes.keys():
            values_table['wpr_' + antecedent] = nodes['__witness_split_' + antecedent]["value"].squeeze().tolist()
            values_table['wlp_' + antecedent] = nodes['__witness_split_' + antecedent]["fn"].log_prob(nodes['__witness_split_' + antecedent]["value"]).squeeze().tolist()

    for witness in witness_candidates:
        if witness not in antecedents:
            values_table['wpr_' + witness] = nodes['__witness_split_' + witness]["value"].squeeze().tolist()
            values_table['wlp_' + witness] = nodes['__witness_split_' + witness]["fn"].log_prob(nodes['__witness_split_' + witness]["value"]).squeeze().tolist()

    values_table['clp'] = nodes['consequent_differs']["fn"].log_prob(nodes['consequent_differs']["value"]).squeeze().tolist()

    # a single run yields scalars, which must be wrapped in a list to form one row
    if isinstance(values_table['clp'], float):
        values_df = pd.DataFrame([values_table])
    else:
        values_df = pd.DataFrame(values_table)

    summands_ant = ['alp_' + antecedent for antecedent in antecedents]
    summands_wit = ['wlp_' + witness for witness in witness_candidates]
    summands = [f"elp_{evaluated_node}"] +  summands_ant + summands_wit + ['clp']

    values_df["int"] =  values_df.apply(lambda row: sum(row[row.index.str.startswith("apr_")] == 0), axis=1)
    values_df['int'] = 1 - values_df[f"epr_{evaluated_node}"] + values_df["int"]
    values_df["wpr"] = values_df.apply(lambda row: sum(row[row.index.str.startswith("wpr_")] == 1), axis=1)
    values_df["changes"] =   values_df["int"] + values_df["wpr"]

    values_df["sum_lp"] =  values_df[summands].sum(axis = 1) 
    values_df.drop_duplicates(inplace = True)
    values_df.sort_values(by = "sum_lp", inplace = True, ascending = False)

    tab =  values_df.reset_index(drop = True)

    tab = remove_redundant_rows(tab)
    
    if round:
       tab = tab.round(3)

    return tab


tab = gett(stones_sallyHPR.trace.trace.nodes, "sally_throws", stones_sallyHPR.treatment_candidates, 
          stones_sallyHPR.witness_candidates)

tab

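|  | epr_sally_throws | elp_sally_throws | apr_sally_hits | alp_sally_hits | wpr_sally_hits | wlp_sally_hits | apr_bill_hits | alp_bill_hits | wpr_bill_hits | wlp_bill_hits | apr_bill_throws | alp_bill_throws | wpr_bill_throws | wlp_bill_throws | clp | int | wpr | changes | sum_lp |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 1 | -0.563 | 0 | -1.204 | 0 | -0.651 | 1 | -0.357 | 1 | -0.737 | 1 | -0.357 | 0 | -0.651 | 0.0 | 1 | 1 | 2 | -4.52 |
| 1 | 1 | -0.563 | 0 | -1.204 | 0 | -0.651 | 1 | -0.357 | 1 | -0.737 | 1 | -0.357 | 1 | -0.737 | 0.0 | 1 | 2 | 3 | -4.606 |
| 2 | 0 | -0.842 | 1 | -0.357 | 0 | -0.651 | 1 | -0.357 | 0 | -0.651 | 0 | -1.204 | 0 | -0.651 | 0.0 | 2 | 0 | 2 | -4.713 |
| 3 | 0 | -0.842 | 1 | -0.357 | 1 | -0.737 | 1 | -0.357 | 0 | -0.651 | 0 | -1.204 | 0 | -0.651 | 0.0 | 2 | 1 | 3 | -4.799 |
| 4 | 1 | -0.563 | 0 | -1.204 | 0 | -0.651 | 1 | -0.357 | 0 | -0.651 | 0 | -1.204 | 0 | -0.651 | 0.0 | 2 | 0 | 2 | -5.281 |
| 5 | 1 | -0.563 | 0 | -1.204 | 0 | -0.651 | 1 | -0.357 | 1 | -0.737 | 0 | -1.204 | 0 | -0.651 | 0.0 | 2 | 1 | 3 | -5.367 |
| 9 | 0 | -0.842 | 0 | -1.204 | 0 | -0.651 | 1 | -0.357 | 0 | -0.651 | 0 | -1.204 | 0 | -0.651 | 0.0 | 3 | 0 | 3 | -5.56 |
| 13 | 1 | -0.563 | 1 | -0.357 | 0 | -0.651 | 1 | -0.357 | 0 | -0.651 | 1 | -0.357 | 0 | -0.651 | -100000000.0 | 0 | 0 | 0 | -100000000.0 |
| 14 | 1 | -0.563 | 1 | -0.357 | 0 | -0.651 | 1 | -0.357 | 1 | -0.737 | 1 | -0.357 | 0 | -0.651 | -100000000.0 | 0 | 1 | 1 | -100000000.0 |
| 15 | 1 | -0.563 | 1 | -0.357 | 0 | -0.651 | 1 | -0.357 | 0 | -0.651 | 1 | -0.357 | 1 | -0.737 | -100000000.0 | 0 | 1 | 1 | -100000000.0 |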


In [23]:
#this is worrying, should be on top with clp == 0

tab.query("epr_sally_throws == 0 & apr_sally_hits == 1 & wpr_sally_hits == 0 & apr_bill_hits == 1 & wpr_bill_hits == 1 & apr_bill_throws == 1 & wpr_bill_throws == 0")

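|  | epr_sally_throws | elp_sally_throws | apr_sally_hits | alp_sally_hits | wpr_sally_hits | wlp_sally_hits | apr_bill_hits | alp_bill_hits | wpr_bill_hits | wlp_bill_hits | apr_bill_throws | alp_bill_throws | wpr_bill_throws | wlp_bill_throws | clp | int | wpr | changes | sum_lp |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 23 | 0 | -0.842 | 1 | -0.357 | 0 | -0.651 | 1 | -0.357 | 1 | -0.737 | 1 | -0.357 | 0 | -0.651 | -100000000.0 | 1 | 1 | 2 | -100000000.0 |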


In [36]:
stones_sallyHPR.trace.trace.nodes["consequent_differs_binary"]

{'type': 'sample',
 'name': 'consequent_differs_binary',
 'fn': MaskedDistribution(),
 'is_observed': True,
 'args': (),
 'kwargs': {},
 'value': tensor([[[[[False, False, False, False, False, False, False, False, False,
             False]]]]]),
 'infer': {'_deterministic': True},
 'scale': 1.0,
 'mask': None,
 'cond_indep_stack': (CondIndepStackFrame(name='runs', dim=-1, size=10, counter=0),),
 'done': True,
 'stop': False,
 'continuation': None}

In [39]:
def get_table(nodes, evaluated_node, antecedents, witness_candidates, round = True):
    
    values_table = {}

    values_table[f"obs_{evaluated_node}"] = nodes[evaluated_node]["value"][0].squeeze().tolist()
    values_table[f"int_{evaluated_node}"] = nodes[evaluated_node]["value"][1].squeeze().tolist()
    values_table[f"epr_{evaluated_node}"] = nodes[f"__evaluated_split_{evaluated_node}"]["value"].squeeze().tolist()
    values_table[f"elp_{evaluated_node}"] = nodes[f"__evaluated_split_{evaluated_node}"]["fn"].log_prob(nodes[f"__evaluated_split_{evaluated_node}"]["value"]).squeeze().tolist()

    for antecedent in antecedents:
        values_table[f"obs_{antecedent}"] = nodes[antecedent]["value"][0].squeeze().tolist()
        values_table[f"int_{antecedent}"] = nodes[antecedent]["value"][1].squeeze().tolist()
        values_table['apr_' + antecedent] = nodes['__treatment_split_' + antecedent]["value"].squeeze().tolist()
        values_table['alp_' + antecedent] = nodes['__treatment_split_' + antecedent]["fn"].log_prob(nodes['__treatment_split_' + antecedent]["value"]).squeeze().tolist()



        if f"__witness_split_{antecedent}" in nodes.keys():
            values_table['wpr_' + antecedent] = nodes['__witness_split_' + antecedent]["value"].squeeze().tolist()
            values_table['wlp_' + antecedent] = nodes['__witness_split_' + antecedent]["fn"].log_prob(nodes['__witness_split_' + antecedent]["value"]).squeeze().tolist()

    # Witness candidates do not depend on the antecedent, so record them once,
    # outside the antecedent loop above.
    for witness in witness_candidates:
        if witness not in antecedents:
            values_table[f"obs_{witness}"] = nodes[witness]["value"][0].squeeze().tolist()
            #values_table[f"int_{witness}"] = nodes[witness]["value"][1].squeeze().tolist()
            values_table['wpr_' + witness] = nodes['__witness_split_' + witness]["value"].squeeze().tolist()
            values_table['wlp_' + witness] = nodes['__witness_split_' + witness]["fn"].log_prob(nodes['__witness_split_' + witness]["value"]).squeeze().tolist()

    
    #values_table['cdif'] = nodes['consequent_differs_binary']["value"].squeeze().tolist()
    #values_table['clp'] = nodes['consequent_differs']["fn"].log_prob(nodes['consequent_differs']["value"]).squeeze().tolist()

    #if isinstance(values_table['clp'], float):
    #    values_df = pd.DataFrame([values_table])
    # else:
    #     values_df = pd.DataFrame(values_table)
    
    # values_df = pd.DataFrame(values_table)

    #summands_ant = ['alp_' + antecedent for antecedent in antecedents]
    #summands_wit = ['wlp_' + witness for witness in witness_candidates]
    #summands = [f"elp_{evaluated_node}"] +  summands_ant + summands_wit + ['clp']
    
    
    # values_df["int"] =  values_df.apply(lambda row: sum(row[row.index.str.startswith("apr_")] == 0), axis=1)
    # values_df['int'] = 1 - values_df[f"epr_{evaluated_node}"] + values_df["int"]
    # values_df["wpr"] = values_df.apply(lambda row: sum(row[row.index.str.startswith("wpr_")] == 1), axis=1)
    # values_df["changes"] =   values_df["int"] + values_df["wpr"]


    #values_df["sum_lp"] =  values_df[summands].sum(axis = 1) 
    # values_df.drop_duplicates(inplace = True)
    # values_df.sort_values(by = "sum_lp", inplace = True, ascending = False)

    # tab =  values_df.reset_index(drop = True)

    # tab = remove_redundant_rows(tab)

    tab = values_table

    #if round:
    #    tab = tab.round(3)

    return tab


get_table(stones_sallyHPR.trace.trace.nodes, "sally_throws", stones_sallyHPR.treatment_candidates, 
          stones_sallyHPR.witness_candidates)

{'obs_sally_throws': [1.0, 1.0, 1.0, 1.0, 1.0],
 'int_sally_throws': [1.0, 1.0, 1.0, 1.0, 0.0],
 'epr_sally_throws': [1, 1, 1, 1, 0],
 'elp_sally_throws': [-0.35953617095947266,
  -0.35953617095947266,
  -0.35953617095947266,
  -0.35953617095947266,
  -1.1973283290863037],
 'obs_bill_throws': [1.0, 1.0, 1.0, 1.0, 1.0],
 'int_bill_throws': [1.0, 0.0, 1.0, 1.0, 1.0],
 'apr_bill_throws': [1, 0, 1, 1, 1],
 'alp_bill_throws': [-0.3566749691963196,
  -1.2039728164672852,
  -0.3566749691963196,
  -0.3566749691963196,
  -0.3566749691963196],
 'obs_bill_hits': [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]],
 'wpr_bill_hits': [0, 1, 0, 0, 0],
 'wlp_bill_hits': [-0.3624056577682495,
  -1.1907275915145874,
  -0.3624056577682495,
  -0.3624056577682495,
  -0.3624056577682495]}

{'type': 'sample', 'name': 'prob_bill_hits', 'fn': Beta(), 'is_observed': True, 'args': (), 'kwargs': {}, 'value': tensor(1.), 'infer': {'_do_not_observe': True}, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (CondIndepStackFrame(name='runs', dim=-1, size=5, counter=0),), 'done': True, 'stop': False, 'continuation': None}


{'obs_sally_throws': [1.0, 1.0, 1.0, 1.0, 1.0],
 'int_sally_throws': [1.0, 1.0, 1.0, 1.0, 0.0],
 'epr_sally_throws': [1, 1, 1, 1, 0],
 'elp_sally_throws': [-0.35953617095947266,
  -0.35953617095947266,
  -0.35953617095947266,
  -0.35953617095947266,
  -1.1973283290863037],
 'obs_bill_throws': [1.0, 1.0, 1.0, 1.0, 1.0],
 'int_bill_throws': [1.0, 1.0, 1.0, 0.0, 1.0],
 'apr_bill_throws': [1, 1, 1, 0, 1],
 'alp_bill_throws': [-0.3566749691963196,
  -0.3566749691963196,
  -0.3566749691963196,
  -1.2039728164672852,
  -0.3566749691963196],
 'wpr_bill_throws': [0, 0, 1, 1, 1],
 'wlp_bill_throws': [-0.5659106969833374,
  -0.5659106969833374,
  -0.8389658331871033,
  -0.8389658331871033,
  -0.8389658331871033]}

In [63]:
# minimal witness size becomes non-trivial here
# we only record different minimal difference-making scenarios

responsibility_check(stones_sallyHPR)

ValueError: All arrays must be of the same length
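
`All arrays must be of the same length` is the message pandas raises when a `DataFrame` is built from a dict whose columns differ in length. `responsibility_check` is not shown here, so the following is only a guess at the source: in the `get_table` output above, `obs_bill_hits` holds two rows of five values, while every other column holds five scalars. The helper below (a hypothetical name, not part of the notebook's code) is a small sketch for spotting such columns before the table is converted.

```python
from collections import Counter


def mismatched_columns(table):
    """Return {column: length} for columns whose length differs from the most common one."""
    lengths = {name: len(values) for name, values in table.items()}
    expected, _ = Counter(lengths.values()).most_common(1)[0]
    return {name: n for name, n in lengths.items() if n != expected}


# Toy table copied from part of the get_table output above.
toy_table = {
    "epr_sally_throws": [1, 1, 1, 1, 0],
    "apr_bill_throws": [1, 0, 1, 1, 1],
    "obs_bill_hits": [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]],
    "wpr_bill_hits": [0, 1, 0, 0, 0],
}

print(mismatched_columns(toy_table))  # {'obs_bill_hits': 2}
```

If this is indeed the culprit, indexing the witness node's value with `[0]` may not select the observed world the way it does for the antecedent nodes, which would be worth checking against the shapes in the trace.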

In [60]:
# following Halpern
# Sally's responsibility is 1/2

responsibility_stones_sally_HPR.responsibility

0.5
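
The value 1/2 can be cross-checked by brute force on a deterministic version of the example. The sketch below does not use `HalpernPearlResponsibilityApproximate`. It assumes the usual structural equations for this story (both throw, both are accurate, and Bill's stone hits only if Sally's does not), and the names `EQUATIONS`, `evaluate`, `smallest_witness` and `responsibility` are illustrative. It enumerates candidate causes and witnesses, keeps only minimal causes (AC3), and returns $1/k$ for the smallest $\vert\vec{W}\vert + \vert\vec{X}\vert$.

```python
from itertools import combinations, product

# Assumed deterministic structural equations, listed in topological order.
EQUATIONS = {
    "sally_throws": lambda v: 1,
    "bill_throws": lambda v: 1,
    "sally_hits": lambda v: v["sally_throws"],
    "bill_hits": lambda v: int(v["bill_throws"] and not v["sally_hits"]),
    "bottle_shatters": lambda v: int(v["sally_hits"] or v["bill_hits"]),
}
OUTCOME = "bottle_shatters"


def evaluate(interventions=None):
    """Solve the model, overriding intervened nodes with their fixed values."""
    interventions = interventions or {}
    values = {}
    for name, equation in EQUATIONS.items():
        values[name] = interventions.get(name, equation(values))
    return values


ACTUAL = evaluate()
CANDIDATES = [name for name in EQUATIONS if name != OUTCOME]


def subsets(items, min_size, max_size):
    """Yield subsets of `items` in order of increasing size."""
    for size in range(min_size, max_size + 1):
        yield from combinations(items, size)


def smallest_witness(cause):
    """Size of the smallest witness for which AC2 holds for `cause`, else None."""
    others = [name for name in CANDIDATES if name not in cause]
    for witness in subsets(others, 0, len(others)):
        frozen = {name: ACTUAL[name] for name in witness}  # witness kept at actual values
        for alternative in product([0, 1], repeat=len(cause)):
            world = evaluate({**frozen, **dict(zip(cause, alternative))})
            if world[OUTCOME] != ACTUAL[OUTCOME]:
                return len(witness)
    return None


def is_minimal(cause):
    """AC3: no non-empty proper subset of `cause` satisfies AC2."""
    proper = subsets(cause, 1, len(cause) - 1)
    return all(smallest_witness(sub) is None for sub in proper)


def responsibility(variable):
    """Naive responsibility: 1/k for the minimal |W| + |X|, or 0 if not part of a cause."""
    best = None
    for cause in subsets(CANDIDATES, 1, len(CANDIDATES)):
        if variable not in cause:
            continue
        witness_size = smallest_witness(cause)
        if witness_size is None or not is_minimal(cause):
            continue
        k = len(cause) + witness_size
        best = k if best is None else min(best, k)
    return 0 if best is None else 1 / best


print(responsibility("sally_throws"))  # 0.5
print(responsibility("bill_throws"))   # 0
```

For Sally, the singleton cause `sally_throws` needs the witness `bill_hits` held at its actual value 0, so $k = 1 + 1 = 2$. For Bill, every set containing `bill_throws` that flips the outcome also contains a smaller set that does, so his throw is not part of any actual cause.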

In [68]:

# Bill has degree of responsibility 0
# for the bottle shattering,
# as his throw is not part of an actual cause

pyro.set_rng_seed(102)

responsibility_stones_bill_HPR = HalpernPearlResponsibilityApproximate(
    model=stones_model,
    nodes=stones_model.nodes,
    antecedent="bill_throws",
    outcome="bottle_shatters",
    observations={
        "prob_sally_throws": 1,
        "prob_bill_throws": 1,
        "prob_sally_hits": 1,
        "prob_bill_hits": 1,
        "prob_bottle_shatters_if_sally": 1,
        "prob_bottle_shatters_if_bill": 1,
        "sally_throws": 1,
        "bill_throws": 1,
    },
    runs_n=10,
)

In [66]:
pyro.set_rng_seed(101)
responsibility_stones_bill_HPR()

In [67]:
responsibility_stones_bill_HPR.responsibility

0

### Firing squad

There is a firing squad consisting of five excellent marksmen. Only one of them has a live bullet in his rifle; the rest have blanks. The marksmen shoot at the prisoner and he dies. The only cause of the prisoner's death is the marksman with the live bullet, so that marksman has degree of responsibility 1 for the death and all the others have degree of responsibility 0. In the notebook on blame (TODO add link) we will see that if the marksmen have no idea which of them has the live bullet, blame is nevertheless distributed equally among them.
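
These two values can be checked by hand-rolled enumeration before turning to the Pyro model. The sketch below is an independent, deterministic illustration, not the notebook's implementation. It assumes `mark0` holds the live bullet and that each marksman is a root node, so freezing other marksmen at their actual values changes nothing and the witness is always empty. Under that assumption $k$ is just the size of the smallest minimal cause containing the marksman. The names `flips_outcome` and `responsibility` are illustrative.

```python
from itertools import combinations, product

ACTUAL = (1, 0, 0, 0, 0)  # assumed: mark0 has the live bullet, the rest have blanks


def dead(marks):
    """The prisoner dies if at least one live bullet is fired."""
    return sum(marks) > 0


def flips_outcome(cause):
    """True if some reassignment of the marksmen in `cause` changes the outcome."""
    for alternative in product([0, 1], repeat=len(cause)):
        marks = list(ACTUAL)
        for index, value in zip(cause, alternative):
            marks[index] = value
        if dead(marks) != dead(ACTUAL):
            return True
    return False


def responsibility(marksman):
    """1/|X| for the smallest minimal cause X containing `marksman`, else 0."""
    n = len(ACTUAL)
    for size in range(1, n + 1):
        for cause in combinations(range(n), size):
            if marksman not in cause or not flips_outcome(cause):
                continue
            proper = [s for r in range(1, size) for s in combinations(cause, r)]
            if not any(flips_outcome(s) for s in proper):  # minimality (AC3)
                return 1 / size
    return 0


print([responsibility(i) for i in range(5)])  # [1.0, 0, 0, 0, 0]
```

Any set that flips the outcome must include `mark0`, and `mark0` alone already flips it, so no larger set is a minimal cause. That is why the marksmen with blanks get 0 rather than a small positive value.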

In [69]:

def firing_squad_model():
    probs = pyro.sample("probs", dist.Dirichlet(torch.ones(5)))

    who_has_bullet = pyro.sample("who_has_bullet", dist.OneHotCategorical(probs))

    mark0 = pyro.deterministic("mark0", torch.tensor([who[0] for who in who_has_bullet]), event_dim=0)
    mark1 = pyro.deterministic("mark1", torch.tensor([who[1] for who in who_has_bullet]), event_dim=0)
    mark2 = pyro.deterministic("mark2", torch.tensor([who[2] for who in who_has_bullet]), event_dim=0)
    mark3 = pyro.deterministic("mark3", torch.tensor([who[3] for who in who_has_bullet]), event_dim=0)
    mark4 = pyro.deterministic("mark4", torch.tensor([who[4] for who in who_has_bullet]), event_dim=0)

    dead = pyro.deterministic("dead", mark0 + mark1 + mark2 + mark3 + 
                                mark4  > 0)
    
    return {"probs": probs,
            "mark0": mark0,
            "mark1": mark1,
            "mark2": mark2,
            "mark3": mark3,
            "mark4": mark4, 
            "dead": dead}



In [70]:
pyro.set_rng_seed(102)

responsibility_loaded_HPR = HalpernPearlResponsibilityApproximate(
    model=firing_squad_model,
    nodes=[f"mark{i}" for i in range(5)],
    antecedent="mark0",
    outcome="dead",
    observations={"probs": torch.tensor([1., 0., 0., 0., 0.])},
    runs_n=50,
)


In [71]:
pyro.set_rng_seed(102)

responsibility_empty_HPR = HalpernPearlResponsibilityApproximate(
    model=firing_squad_model,
    nodes=[f"mark{i}" for i in range(5)],
    antecedent="mark1",
    outcome="dead",
    observations={"probs": torch.tensor([1., 0., 0., 0., 0.])},
    runs_n=50,
)

In [75]:
# If you have the live bullet

responsibility_loaded_HPR()
responsibility_loaded_HPR.responsibility

1.0

In [74]:
# If you have a blank:
# since the bullet's location is held fixed in the model,
# nothing can make mark1's shot matter to the outcome,
# so his responsibility is zero

responsibility_empty_HPR()
responsibility_empty_HPR.responsibility

0