In [2]:
import functools

import numpy as np

import torch
from typing import Dict, List, Optional,  Union, Callable

import pandas as pd

import pyro
import pyro.distributions as dist

import random

from causal_pyro.indexed.ops import IndexSet, gather, indices_of, scatter
from causal_pyro.interventional.handlers import do
from causal_pyro.counterfactual.handlers import MultiWorldCounterfactual, Preemptions

In [7]:
class HalpernPearlModifiedApproximate:

    def __init__(
        self, 
        model: Callable,
        antecedents: Union[Dict[str, torch.Tensor], List[str]],
        outcome: str,
        witness_candidates: List[str],
        observations: Optional[Dict[str, torch.Tensor]],
        sample_size: int = 100,
        event_dim: int = 0
        ):
        
        self.model = model
        self.antecedents = antecedents
        self.outcome = outcome
        self.witness_candidates = witness_candidates
        self.observations = observations
        self.sample_size = sample_size

        self.antecedents_dict = (
            self.antecedents if isinstance(self.antecedents, dict)
            else self.revert_antecedents(self.antecedents)
        )
    
        self.preemptions = {candidate: functools.partial(self.preempt_with_factual,
                                             antecedents = self.antecedents) for 
                                             candidate in self.witness_candidates}
        

    @staticmethod
    def revert_antecedents(antecedents: List[str]) -> Dict[str, Callable[[torch.Tensor], torch.Tensor]]:
        return {antecedent: (lambda v: 1 - v) for antecedent in antecedents}

    @staticmethod   
    def preempt_with_factual(value: torch.Tensor, *,
                          antecedents: List[str] = None, event_dim: int = 0):
    
        if antecedents is None:
            antecedents = []

        antecedents = [a for a in antecedents if a in indices_of(value, event_dim=event_dim)]

        factual_value = gather(value, IndexSet(**{antecedent: {0} for antecedent in antecedents}),
                                event_dim=event_dim)
            
        return scatter({
            IndexSet(**{antecedent: {0} for antecedent in antecedents}): factual_value,
            IndexSet(**{antecedent: {1} for antecedent in antecedents}): factual_value,
        }, event_dim=event_dim)
        
        
    def __call__(self, *args, **kwargs):
        with pyro.poutine.trace() as trace:
            with MultiWorldCounterfactual():
                with do(actions=self.antecedents_dict):
                    with Preemptions(actions = self.preemptions):
                        with pyro.condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):
                            with pyro.plate("plate", self.sample_size):
                                self.consequent = self.model()[self.outcome]
                                self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.antecedents}))
                                self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.antecedents}))
                                self.consequent_differs = self.intervened_consequent != self.observed_consequent   
                                pyro.factor("consequent_differs", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))
                            
        self.trace = trace.trace

        # # slightly hacky solution for cases with no witness candidates
        self.existential_but_for = any(self.consequent_differs.squeeze().tolist()
                                        ) if self.witness_candidates else self.consequent_differs.squeeze()
        
        witness_dict = dict()
        if self.witness_candidates:
            witness_keys = ["__split_" + candidate for candidate in self.witness_candidates]
            witness_dict = {key: self.trace.nodes[key]['value']  for key in witness_keys}
            

        witness_dict['observed'] = self.observed_consequent.squeeze()
        witness_dict['intervened'] = self.intervened_consequent.squeeze()
        witness_dict['consequent_differs'] = self.consequent_differs.squeeze()

        # slightly hacky as above
        self.witness_df = pd.DataFrame(witness_dict) if self.witness_candidates else witness_dict

        if self.witness_candidates:
            self.witness_df['witness_size'] = self.witness_df[witness_keys].sum(axis = 1)
            satisfactory = self.witness_df[self.witness_df['consequent_differs'] == True]
            
        self.minimal_witness_size = satisfactory['witness_size'].min() if self.witness_candidates else 0
        self.responsibility = 1/(len(self.antecedents) + self.minimal_witness_size)


In [173]:
def voting_model():
    u_vote0 = pyro.sample("u_vote0", dist.Bernoulli(0.6))
    u_vote1 = pyro.sample("u_vote1", dist.Bernoulli(0.6))
    u_vote2 = pyro.sample("u_vote2", dist.Bernoulli(0.6))
    u_vote3 = pyro.sample("u_vote3", dist.Bernoulli(0.6))
    u_vote4 = pyro.sample("u_vote4", dist.Bernoulli(0.6))
    u_vote5 = pyro.sample("u_vote5", dist.Bernoulli(0.6))

    vote0 = pyro.deterministic("vote0", u_vote0, event_dim=0)
    vote1 = pyro.deterministic("vote1", u_vote1, event_dim=0)
    vote2 = pyro.deterministic("vote2", u_vote2, event_dim=0)
    vote3 = pyro.deterministic("vote3", u_vote3, event_dim=0)
    vote4 = pyro.deterministic("vote4", u_vote4, event_dim=0)
    vote5 = pyro.deterministic("vote5", u_vote5, event_dim=0)

    outcome = pyro.deterministic("outcome", vote0 + vote1 + vote2 + vote3 + vote4 + vote5 > 3)
    return {"outcome": outcome.float()}

voting_model()

{'outcome': tensor(0.)}

In [176]:
# if you're one of four voters who voted for, you are an actual cause
voting4HPM = HalpernPearlModifiedApproximate(
    model = voting_model,
    antecedents = ["vote0"],
    outcome = "outcome",
    witness_candidates = [f"vote{i}" for i in range(1,6)],
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=0., u_vote_5=0),
    sample_size = 1000)

voting4HPM()

print(
voting4HPM.existential_but_for
)

print(
voting4HPM.minimal_witness_size
)

print(voting4HPM.responsibility)


True
0
1.0


In [182]:
pyro.set_rng_seed(115)
sample_size = 2

model = voting_model
antecedents = ['vote0', 'vote1']
outcome = "outcome"
witness_candidates = [f"vote{i}" for i in range(2,6)]
observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                    u_vote3=1., u_vote4=1., u_vote_5=1.)

# with pyro.poutine.trace() as trace:
#             with MultiWorldCounterfactual():
#                 with do(actions=antecedents_dict):
#                     with Preemptions(actions = self.preemptions):
#                         with pyro.condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):
with pyro.plate("plate", sample_size):
    consequent = model()[outcome]
    intervened_consequent = gather(consequent, IndexSet(**{ant: {1} for ant in antecedents}))
    #observed_consequent = gather(consequent, IndexSet(**{ant: {0} for ant in antecedents}))
    #consequent_differs = intervened_consequent != observed_consequent   
    #pyro.factor("consequent_differs", torch.where(consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))

print(consequent)

KeyError: 'vote0'

In [175]:

pyro.set_rng_seed(122)
# if you're one of four voters who voted for, you are an actual cause
voting6HPM = HalpernPearlModifiedApproximate(
    model = voting_model,
    antecedents = ['vote0', 'vote1'],
    outcome = "outcome",
    witness_candidates = [f"vote{i}" for i in range(2,6)],
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=1., u_vote_5=1.),
    sample_size = 2)

voting6HPM()

print(
voting6HPM.existential_but_for
)

print(
voting6HPM.witness_df
)

print(
"v0:", voting6HPM.trace.nodes['vote0']['value'],
"v1:", voting6HPM.trace.nodes['vote1']['value'], 
"v2:", voting6HPM.trace.nodes['vote2']['value'],
"v3:", voting6HPM.trace.nodes['vote3']['value'],
"v4:", voting6HPM.trace.nodes['vote4']['value'],
"v5:", voting6HPM.trace.nodes['vote5']['value'],
"outcome:" , voting6HPM.trace.nodes['outcome']['value']
)

TypeError: 'bool' object is not iterable

In [122]:
#pyro.set_rng_seed(101)
voters = [f"vote{i}" for i in range(1,6,)]
companion_size = 1 #random.randint(0,len(voters))
companion_candidates = ["vote0"] +  random.sample(voters, companion_size)
witness_candidates = [f"vote{i}" for i in range(1,6) if f"vote{i}" not in companion_candidates]

voting_run_HPM = HalpernPearlModifiedApproximate(
    model = voting_model,
    antecedents = companion_candidates,
    outcome = "outcome",
    witness_candidates = witness_candidates,
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=1., u_vote_5=1.),
    sample_size = 1000)

voting_run_HPM()

print(
voting_run_HPM.existential_but_for,

"antecedents:", voting_run_HPM.antecedents,

"antecedents_length:", len(voting_run_HPM.antecedents),


voting_run_HPM.minimal_witness_size,

)

True antecedents: ['vote0', 'vote1'] antecedents_length: 2 0


In [80]:
antecedent_sizes = []
existential_but_fors = []
minimal_witness_sizes = []
responsibilities = []

for step in range(1,15):

    voters = [f"vote{i}" for i in range(1,6,)]
    companion_size = random.randint(0,len(voters))
    companion_candidates = ["vote0"] +  random.sample(voters, companion_size)
    witness_candidates = [f"vote{i}" for i in range(1,6) if f"vote{i}" not in companion_candidates]
    antecedent_sizes.append(len(companion_candidates))

    voting_run_HPM = HalpernPearlModifiedApproximate(
        model = voting_model,
        antecedents = companion_candidates,
        outcome = "outcome",
        witness_candidates = witness_candidates,
        observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                            u_vote3=1., u_vote4=1., u_vote_5=1.),
        sample_size = 1000)

    voting_run_HPM()

    existential_but_fors.append(voting_run_HPM.existential_but_for)
    minimal_witness_sizes.append(voting_run_HPM.minimal_witness_size)
    responsibilities.append(voting_run_HPM.responsibility)

denumerators = [x + y for x, y in zip(antecedent_sizes, minimal_witness_sizes)]

responsibilityDF = pd.DataFrame({"existential_but_for": existential_but_fors, "antecedent_size": antecedent_sizes, "minimal_witness_size": minimal_witness_sizes,
                                 "denumerator": denumerators,
                                 "responsibility": responsibilities})




In [56]:

print(responsibilityDF)


    existential_but_for  antecedent_size  minimal_witness_size  responsibility
0                  True                2                   0.0        0.500000
1                 False                1                   NaN             NaN
2                  True                3                   0.0        0.333333
3                  True                2                   0.0        0.500000
4                  True                5                   0.0        0.200000
5                 False                1                   NaN             NaN
6                  True                3                   0.0        0.333333
7                  True                2                   0.0        0.500000
8                  True                2                   0.0        0.500000
9                  True                5                   0.0        0.200000
10                False                1                   NaN             NaN
11                 True                5            

In [None]:


voting4HPM = HalpernPearlModifiedApproximate(
    model = voting_model,
    antecedents = ["vote0"],
    outcome = "outcome",
    witness_candidates = [f"vote{i}" for i in range(1,6)],
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=0., u_vote_5=0),
    sample_size = 1000)

voting4HPM()

print(
voting4HPM.existential_but_for
)

