In [1]:
import functools

import numpy as np

import torch
from typing import Dict, List, Optional,  Union, Callable

import pandas as pd

import pyro
import pyro.distributions as dist

import random

from causal_pyro.indexed.ops import IndexSet, gather, indices_of, scatter
from causal_pyro.interventional.handlers import do
from causal_pyro.counterfactual.handlers import MultiWorldCounterfactual, Preemptions

## Implementing responsibility calculations

TODO: ADD LINK
We're using pretty much the same implementation as in the actual causality notebook, except for the last few lines, where we (1) keep track of witness set sizes, (2) select those for which the relative but-for clause holds, (3) find the minimum of such sizes, and (4) follow Halpern's definiton by dividing 1 by the sum of the ancedent list length with the minimal satisfactory witness size.

In [2]:
class HalpernPearlModifiedApproximate:

    def __init__(
        self, 
        model: Callable,
        antecedents: Union[Dict[str, torch.Tensor], List[str]],
        outcome: str,
        witness_candidates: List[str],
        observations: Optional[Dict[str, torch.Tensor]],
        sample_size: int = 100,
        event_dim: int = 0
        ):
        
        self.model = model
        self.antecedents = antecedents
        self.outcome = outcome
        self.witness_candidates = witness_candidates
        self.nodes = antecedents + [outcome] + witness_candidates
        self.observations = observations
        self.sample_size = sample_size

        self.antecedents_dict = (
            self.antecedents if isinstance(self.antecedents, dict)
            else self.revert_antecedents(self.antecedents)
        )
    
        self.preemptions = {candidate: functools.partial(self.preempt_with_factual,
                                             antecedents = self.antecedents) for 
                                             candidate in self.witness_candidates}
        

    @staticmethod
    def revert_antecedents(antecedents: List[str]) -> Dict[str, Callable[[torch.Tensor], torch.Tensor]]:
        return {antecedent: (lambda v: 1 - v) for antecedent in antecedents}

    @staticmethod   
    def preempt_with_factual(value: torch.Tensor, *,
                          antecedents: List[str] = None, event_dim: int = 0):
    
        if antecedents is None:
            antecedents = []

        antecedents = [a for a in antecedents if a in indices_of(value, event_dim=event_dim)]

        factual_value = gather(value, IndexSet(**{antecedent: {0} for antecedent in antecedents}),
                                event_dim=event_dim)
            
        return scatter({
            IndexSet(**{antecedent: {0} for antecedent in antecedents}): factual_value,
            IndexSet(**{antecedent: {1} for antecedent in antecedents}): factual_value,
        }, event_dim=event_dim)
        
        
    def __call__(self, *args, **kwargs):
        with pyro.poutine.trace() as trace:
            with MultiWorldCounterfactual():
                with do(actions=self.antecedents_dict):
                    with Preemptions(actions = self.preemptions):
                        with pyro.condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):
                            with pyro.plate("plate", self.sample_size):
                                self.consequent = self.model()[self.outcome]
                                self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.antecedents}))
                                self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.antecedents}))
                                self.consequent_differs = self.intervened_consequent != self.observed_consequent   
                                pyro.factor("consequent_differs", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))
                            
        self.trace = trace.trace
        self.nodes_trace = {node: self.trace.nodes[node]['value'] for node in self.nodes}
        
        

        # # slightly hacky solution for cases with no witness candidates
        self.existential_but_for = any(self.consequent_differs.squeeze().tolist()
                                        ) if self.witness_candidates else self.consequent_differs.squeeze()
      

        witness_dict = dict()
        if self.witness_candidates:
            witness_keys = ["__split_" + candidate for candidate in self.witness_candidates]
            witness_dict = {key: self.trace.nodes[key]['value']  for key in witness_keys}
            

        witness_dict['observed'] = self.observed_consequent.squeeze()
        witness_dict['intervened'] = self.intervened_consequent.squeeze()
        witness_dict['consequent_differs'] = self.consequent_differs.squeeze()

        # slightly hacky as above
        self.witness_df = pd.DataFrame(witness_dict) if self.witness_candidates else witness_dict

        if self.witness_candidates:
            self.witness_df['witness_size'] = self.witness_df[witness_keys].sum(axis = 1)
            satisfactory = self.witness_df[self.witness_df['consequent_differs'] == True]
            
        self.minimal_witness_size = satisfactory['witness_size'].min() if self.witness_candidates else 0
        self.responsibility = 1/(len(self.antecedents) + self.minimal_witness_size)


## Responsibility in voters scenarios

In [3]:
def voting_model():
    u_vote0 = pyro.sample("u_vote0", dist.Bernoulli(0.6))
    u_vote1 = pyro.sample("u_vote1", dist.Bernoulli(0.6))
    u_vote2 = pyro.sample("u_vote2", dist.Bernoulli(0.6))
    u_vote3 = pyro.sample("u_vote3", dist.Bernoulli(0.6))
    u_vote4 = pyro.sample("u_vote4", dist.Bernoulli(0.6))
    u_vote5 = pyro.sample("u_vote5", dist.Bernoulli(0.6))
    u_vote6 = pyro.sample("u_vote6", dist.Bernoulli(0.6))
    u_vote7 = pyro.sample("u_vote7", dist.Bernoulli(0.6))

    vote0 = pyro.deterministic("vote0", u_vote0, event_dim=0)
    vote1 = pyro.deterministic("vote1", u_vote1, event_dim=0)
    vote2 = pyro.deterministic("vote2", u_vote2, event_dim=0)
    vote3 = pyro.deterministic("vote3", u_vote3, event_dim=0)
    vote4 = pyro.deterministic("vote4", u_vote4, event_dim=0)
    vote5 = pyro.deterministic("vote5", u_vote5, event_dim=0)
    vote6 = pyro.deterministic("vote6", u_vote6, event_dim=0)
    vote7 = pyro.deterministic("vote7", u_vote7, event_dim=0)


    outcome = pyro.deterministic("outcome", vote0 + vote1 + vote2 + vote3 + 
                                 vote4 + vote5 + vote6 + vote7 > 4)
    return {"outcome": outcome.float()}

voting_model()

{'outcome': tensor(1.)}

In [4]:
# if you're one of five voters who voted for, you are an actual cause
# and your responsibility is 1 
voting5HPM = HalpernPearlModifiedApproximate(
    model = voting_model,
    antecedents = ["vote0"],
    outcome = "outcome",
    witness_candidates = [f"vote{i}" for i in range(1,8)],
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=1., u_vote5=0,
                        u_vote6=0., u_vote7=0.),
    sample_size = 1000)

voting5HPM()

print(
voting5HPM.existential_but_for
)

print(
voting5HPM.minimal_witness_size
)

print(voting5HPM.responsibility)


True
0
1.0


In [5]:
# if everyone voted for, you are not an actual cause
voting8HPM = HalpernPearlModifiedApproximate(
    model = voting_model,
    antecedents = ['vote0'],
    outcome = "outcome",
    witness_candidates = [f"vote{i}" for i in range(1,8)],
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=1., u_vote5=1.,
                        u_vote6=1., u_vote7=1.),
    sample_size = 1000)

voting8HPM()

print(
voting8HPM.existential_but_for
)

False


In [6]:
# but you're part of an actual cause

voting8_bHPM = HalpernPearlModifiedApproximate(
    model = voting_model,
    antecedents = ['vote0', "vote1", "vote2", "vote3"],
    outcome = "outcome",
    witness_candidates = [f"vote{i}" for i in range(4,8)],
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=1., u_vote5=1.,
                        u_vote6=1., u_vote7=1.),
    sample_size = 1000)

voting8_bHPM()

print(
voting8_bHPM.existential_but_for
)

True


In [53]:
class HalpernPearlResponsibilityApproximate:
    
#(model, nodes, antecedent, outcome, observations, runs_n)

    def __init__(
        self, 
        model: Callable,
        nodes: List,
        antecedent: str,
        outcome: str,
        observations: Dict[str, torch.Tensor], 
        runs_n: int 
    ):
        self.model = model
        self.nodes = nodes
        self.antecedent = antecedent
        self.outcome = outcome
        self.observations = observations
        self.runs_n = runs_n
        
        self.minimal_antecedents_cache = []
        self.antecedent_sizes = []
        self.existential_but_fors = []
        self.minimal_witness_sizes = []
        self.responsibilities = []
        self.HPMs = []

    def __call__(self):
        
        for step in range(1,self.runs_n):

            nodes = self.nodes
            nodes.remove(self.outcome) if self.outcome in nodes
            companion_size = random.randint(0,len(nodes))
            companion_candidates = random.sample(self.nodes, companion_size)
            witness_candidates = [node for node in self.nodes if 
                                node != self.antecedent and 
                                node != self.outcome and 
                                    node not in companion_candidates]

            HPM = HalpernPearlModifiedApproximate(
                model = self.model,
                antecedents = companion_candidates,
                outcome = self.outcome,
                witness_candidates = witness_candidates,
                observations = self.observations,
                sample_size = 1000)
            
            HPM()

            self.HPMs.append(HPM)

            print("ebf", HPM.existential_but_for)

            if  True: # HPM.existential_but_for:


                subset_in_cache = any([s.issubset(set(HPM.antecedents)) for s in self.minimal_antecedents_cache])
                if not subset_in_cache:
                    self.minimal_antecedents_cache.append(set(HPM.antecedents))

                    if self.antecedent in HPM.antecedents:
                        self.antecedent_sizes.append(len(HPM.antecedents))
                        self.existential_but_fors.append(HPM.existential_but_for)
                        self.minimal_witness_sizes.append(HPM.minimal_witness_size)
                        self.responsibilities.append(HPM.responsibility)


        self.denumerators = [x + y for x, y in zip(self.antecedent_sizes, self.minimal_witness_sizes)]

        self.responsibilityDF = pd.DataFrame(
            {"existential_but_for": [bool(value) for value in self.existential_but_fors],
                "antecedent_size": self.antecedent_sizes, 
                "minimal_witness_size": self.minimal_witness_sizes,
                "denumerator": self.denumerators,
                "responsibility": self.responsibilities
            }
            )
        
        min_denumerator = min(self.responsibilityDF['denumerator'])
        self.responsibility = 1/min_denumerator

     
 

SyntaxError: expected 'else' after 'if' expression (2533124503.py, line 33)

In [8]:
everyone_voted_HPR = HalpernPearlResponsibilityApproximate(
    model = voting_model,
    nodes = [f"vote{i}" for i in range(0,8,)],
    antecedent = "vote0", outcome = "outcome",
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
    u_vote3=1., u_vote4=1., u_vote5=1., u_vote6 = 1., u_vote7 = 1.), 
    runs_n=500
    )

everyone_voted_HPR()

ebf True
exbf: True
ebf tensor(True)
exbf: tensor(True)
ebf False
exbf: False
ebf False
exbf: False
ebf False
exbf: False
ebf True
exbf: True
ebf True
exbf: True
ebf False
exbf: False
ebf True
exbf: True
ebf False
exbf: False
ebf False
exbf: False
ebf False
exbf: False
ebf True
exbf: True
ebf True
exbf: True
ebf False
exbf: False
ebf True
exbf: True
ebf False
exbf: False
ebf False
exbf: False
ebf False
exbf: False
ebf tensor(True)
exbf: tensor(True)
ebf False
exbf: False
ebf True
exbf: True
ebf True
exbf: True
ebf False
exbf: False
ebf True
exbf: True
ebf False
exbf: False
ebf False
exbf: False
ebf tensor(True)
exbf: tensor(True)
ebf True
exbf: True
ebf False
exbf: False
ebf False
exbf: False
ebf True
exbf: True
ebf True
exbf: True
ebf False
exbf: False
ebf False
exbf: False
ebf True
exbf: True
ebf False
exbf: False
ebf True
exbf: True
ebf True
exbf: True
ebf False
exbf: False
ebf True
exbf: True
ebf False
exbf: False
ebf False
exbf: False
ebf True
exbf: True
ebf False
exbf: False
ebf 

In [9]:
everyone_voted_HPR.responsibilityDF


Unnamed: 0,existential_but_for,antecedent_size,minimal_witness_size,denumerator,responsibility
0,True,6,0.0,6.0,0.166667
1,False,3,,,
2,False,2,,,
3,False,2,,,
4,False,2,,,


In [10]:
# four people would need to change their votes
# to change the outcome
# so your responsibility is 1/4

everyone_voted_HPR.responsibility

0.16666666666666666

In [11]:
seven_voted_for_HPR = HalpernPearlResponsibilityApproximate(
    model = voting_model,
    nodes = [f"vote{i}" for i in range(0,8,)],
    antecedent = "vote0", outcome = "outcome",
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
    u_vote3=1., u_vote4=1., u_vote5=1., u_vote6 = 1., u_vote7 = 0.), 
    runs_n=500
    )

seven_voted_for_HPR()



ebf True
exbf: True
ebf True
exbf: True
ebf True
exbf: True
ebf True
exbf: True
ebf False
exbf: False
ebf False
exbf: False
ebf True
exbf: True
ebf False
exbf: False
ebf False
exbf: False
ebf True
exbf: True
ebf True
exbf: True
ebf False
exbf: False
ebf True
exbf: True
ebf True
exbf: True
ebf False
exbf: False
ebf tensor(True)
exbf: tensor(True)
ebf False
exbf: False
ebf False
exbf: False
ebf tensor(True)
exbf: tensor(True)
ebf False
exbf: False
ebf False
exbf: False
ebf False
exbf: False
ebf True
exbf: True
ebf False
exbf: False
ebf False
exbf: False
ebf False
exbf: False
ebf tensor(True)
exbf: tensor(True)
ebf False
exbf: False
ebf tensor(True)
exbf: tensor(True)
ebf False
exbf: False
ebf True
exbf: True
ebf True
exbf: True
ebf True
exbf: True
ebf False
exbf: False
ebf False
exbf: False
ebf True
exbf: True
ebf tensor(True)
exbf: tensor(True)
ebf False
exbf: False
ebf True
exbf: True
ebf True
exbf: True
ebf True
exbf: True
ebf False
exbf: False
ebf False
exbf: False
ebf False
exbf: Fa

In [12]:
seven_voted_for_HPR.responsibilityDF

Unnamed: 0,existential_but_for,antecedent_size,minimal_witness_size,denumerator,responsibility
0,True,7,0,7,0.142857
1,True,4,0,4,0.25


In [13]:
# your responsibility is 1/3 as in this case
# it would be enough for three people to vote against
# to change the outcome

seven_voted_for_HPR.responsibility

0.25

## Responsibility in stone-throwing


This is worth looking at, as the causal structure is less trivial

In [14]:
def stones_model():        
    prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1))
    prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1))
    prob_sally_hits = pyro.sample("prob_sally_hits", dist.Beta(1, 1))
    prob_bill_hits = pyro.sample("prob_bill_hits", dist.Beta(1, 1))
    prob_bottle_shatters_if_sally = pyro.sample("prob_bottle_shatters_if_sally", dist.Beta(1, 1))
    prob_bottle_shatters_if_bill = pyro.sample("prob_bottle_shatters_if_bill", dist.Beta(1, 1))


    sally_throws = pyro.sample("sally_throws", dist.Bernoulli(prob_sally_throws))
    bill_throws = pyro.sample("bill_throws", dist.Bernoulli(prob_bill_throws))

    new_shp = torch.where(sally_throws == 1,prob_sally_hits , 0.0)

    sally_hits = pyro.sample("sally_hits",dist.Bernoulli(new_shp))

    new_bhp = torch.where(
            (
                bill_throws.bool()
                & (~sally_hits.bool())
            )
            == 1,
            prob_bill_hits,
            torch.tensor(0.0),
        )


    bill_hits = pyro.sample("bill_hits", dist.Bernoulli(new_bhp))

    new_bsp = torch.where(
            bill_hits.bool() == 1,
            prob_bottle_shatters_if_bill,
            torch.where(
                sally_hits.bool() == 1,
                prob_bottle_shatters_if_sally,
                torch.tensor(0.0),
            ),
        )

    bottle_shatters = pyro.sample(
            "bottle_shatters", dist.Bernoulli(new_bsp)
        )

    return {
            "sally_throws": sally_throws,
            "bill_throws": bill_throws,
            "sally_hits": sally_hits,
            "bill_hits": bill_hits,
            "bottle_shatters": bottle_shatters,
        }

stones_model.nodes = [
            "sally_throws",
            "bill_throws",
            "sally_hits",
            "bill_hits",
            "bottle_shatters",
        ]

In [49]:
pyro.set_rng_seed(101)

# everyone voted for
responsibility_stones_sally_HPR = HalpernPearlResponsibilityApproximate(
    model = stones_model,
    nodes = stones_model.nodes,
    antecedent = "sally_throws", outcome = "bottle_shatters",
    observations = {"prob_sally_throws": 1, 
                    "prob_bill_throws": 1,
                    "prob_sally_hits": 1,
                    "prob_bill_hits": 1,
                    "prob_bottle_shatters_if_sally": 1,
                    "prob_bottle_shatters_if_bill": 1,
                    "sally_throws": 1, "bill_throws": 1},
                      runs_n=3)

In [51]:
pyro.set_rng_seed(101)
responsibility_stones_sally_HPR()

ValueError: list.remove(x): x not in list

In [44]:
hpm2 = responsibility_stones_sally_HPR.HPMs[1]

dir(hpm2)

#hpm2.existential_but_for
#hpm2.sample_size

hpm2.antecedents


['bottle_shatters', 'bill_throws', 'sally_hits', 'sally_throws', 'bill_hits']

In [32]:
responsibility_stones_sally_HPR.existential_but_fors

[True]

In [34]:
# following Halpern
# Suzy's responsibility is 1/2


0.3333333333333333

In [21]:
# Halpern says:
# Billy has degree of responsibility 0
#  for the bottle shattering, since his throw was not a cause of the outcome.

# But this argument doesn't work, as items that aren't causes can have responsibility 
# (see his own treatment of the voters case)

stonesHPM = HalpernPearlModifiedApproximate(
    model = stones_model,
    antecedents = ["sally_throws", "bill_throws"],
    outcome = "bottle_shatters",
    witness_candidates = ["bill_throws", "bill_hits"],
    observations = {"prob_sally_throws": 1, 
                    "prob_bill_throws": 1,
                    "prob_sally_hits": 1,
                    "prob_bill_hits": 1,
                    "prob_bottle_shatters_if_sally": 1,
                    "prob_bottle_shatters_if_bill": 1,
                    "sally_throws": 1, "bill_throws": 1},
    sample_size = 6,
    event_dim = 0
)

stonesHPM()

print(
stonesHPM.witness_df
)

print(stonesHPM.existential_but_for)

   __split_bill_throws  __split_bill_hits  observed  intervened  \
0                    1                  0       1.0         1.0   
1                    0                  0       1.0         0.0   
2                    1                  1       1.0         0.0   
3                    0                  0       1.0         0.0   
4                    0                  1       1.0         0.0   
5                    1                  0       1.0         1.0   

   consequent_differs  witness_size  
0               False             1  
1                True             0  
2                True             2  
3                True             0  
4                True             1  
5               False             1  
True


In [23]:

stones_responsibility_bill_DF = stones_responsibility(observations, "bill_throws", "bottle_shatters", 100)

stones_responsibility_bill_DF[stones_responsibility_bill_DF['existential_but_for']]

Unnamed: 0,existential_but_for,antecedent_size,minimal_witness_size,denumerator,responsibility
5,True,2,0.0,2.0,0.5
12,True,2,0.0,2.0,0.5
21,True,2,0.0,2.0,0.5
28,True,2,0.0,2.0,0.5
33,True,2,0.0,2.0,0.5
34,True,2,0.0,2.0,0.5
36,True,2,0.0,2.0,0.5
46,True,2,0.0,2.0,0.5
47,True,2,0.0,2.0,0.5
52,True,2,0.0,2.0,0.5


In [130]:

min_den_stones = min(stones_responsibility_bill_DF[stones_responsibility_bill_DF['existential_but_for']]['denumerator'])
1/min_den_stones


0.5

In [29]:
pyro.set_rng_seed(101)

stonesHPM = HalpernPearlModifiedApproximate(
    model = stones_model,
    antecedents = ["sally_throws"],
    outcome = "bottle_shatters",
    witness_candidates = ["bill_throws", "bill_hits"],
    observations = {"prob_sally_throws": 1, 
                    "prob_bill_throws": 1,
                    "prob_sally_hits": 1,
                    "prob_bill_hits": 1,
                    "prob_bottle_shatters_if_sally": 1,
                    "prob_bottle_shatters_if_bill": 1,
                    "sally_throws": 1, "bill_throws": 1},
    sample_size = 6,
    event_dim = 0
)

stonesHPM()

print(
stonesHPM.witness_df
)

print(stonesHPM.existential_but_for)

   __split_bill_throws  __split_bill_hits  observed  intervened  \
0                    1                  1       1.0         0.0   
1                    1                  1       1.0         0.0   
2                    0                  1       1.0         0.0   
3                    0                  1       1.0         0.0   
4                    0                  0       1.0         1.0   
5                    1                  0       1.0         1.0   

   consequent_differs  witness_size  
0                True             2  
1                True             2  
2                True             1  
3                True             1  
4               False             0  
5               False             1  
True
