In [1]:
%reload_ext autoreload
%autoreload 2
%pdb off

import functools

import torch
from typing import Dict, List, Optional, Tuple, Union, TypeVar, Callable

import pandas as pd

import pyro
import pyro.distributions as dist

from causal_pyro.indexed.ops import IndexSet, gather, indices_of, scatter
from causal_pyro.interventional.handlers import do
from causal_pyro.counterfactual.handlers import MultiWorldCounterfactual, Preemptions

Automatic pdb calling has been turned OFF


In [2]:
class HalpernPearlModifiedApproximate:

    def __init__(
        self, 
        model: Callable,
        antecedents: Union[Dict[str, torch.Tensor], List[str]],
        outcome: str,
        witness_candidates: List[str],
        observations: Optional[Dict[str, torch.Tensor]],
        sample_size: int = 100,
        event_dim: int = 0
        ):
        
        self.model = model
        self.antecedents = antecedents
        self.outcome = outcome
        self.witness_candidates = witness_candidates
        self.observations = observations
        self.sample_size = sample_size

        self.antecedents_dict = (
            self.antecedents if isinstance(self.antecedents, dict)
            else self.revert_antecedents(self.antecedents)
        )
    
        self.preemptions = {candidate: functools.partial(self.preempt_with_factual,
                                             antecedents = self.antecedents) for 
                                             candidate in self.witness_candidates}
        

    @staticmethod
    def revert_antecedents(antecedents: List[str]) -> Dict[str, Callable[[torch.Tensor], torch.Tensor]]:
        return {antecedent: (lambda v: 1 - v) for antecedent in antecedents}

    @staticmethod   
    def preempt_with_factual(value: torch.Tensor, *,
                          antecedents: List[str] = None, event_dim: int = 0):
    
        if antecedents is None:
            antecedents = []

        antecedents = [a for a in antecedents if a in indices_of(value, event_dim=event_dim)]

        factual_value = gather(value, IndexSet(**{antecedent: {0} for antecedent in antecedents}),
                                event_dim=event_dim)
            
        return scatter({
            IndexSet(**{antecedent: {0} for antecedent in antecedents}): factual_value,
            IndexSet(**{antecedent: {1} for antecedent in antecedents}): factual_value,
        }, event_dim=event_dim)
        
        
    def __call__(self, *args, **kwargs):
        with pyro.poutine.trace() as trace:
            with MultiWorldCounterfactual():
                with do(actions=self.antecedents_dict):
                    with Preemptions(actions = self.preemptions):
                        with pyro.condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):
                            with pyro.plate("plate", self.sample_size):
                                self.consequent = self.model()[self.outcome]
                                self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.antecedents}))
                                self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.antecedents}))
                                self.consequent_differs = self.intervened_consequent != self.observed_consequent   
                                pyro.factor("consequent_differs", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))
                            
        self.trace = trace.trace

        # # slightly hacky solution for cases with no witness candidates
        self.existential_but_for = any(self.consequent_differs.squeeze().tolist()
                                        ) if self.witness_candidates else self.consequent_differs.squeeze()
        
        witness_dict = dict()
        if self.witness_candidates:
            witness_keys = ["__split_" + candidate for candidate in self.witness_candidates]
            witness_dict = {key: self.trace.nodes[key]['value']  for key in witness_keys}
            witness_dict['observed'] = self.observed_consequent.squeeze()

        witness_dict['intervened'] = self.intervened_consequent.squeeze()
        witness_dict['consequent_differs'] = self.consequent_differs.squeeze()

        # slightly hacky as above
        self.witness_df = pd.DataFrame(witness_dict) if self.witness_candidates else witness_dict

    


## Stone throwing

In [4]:
def stones_model():        
    prob_sally_throws = pyro.sample("prob_sally_throws", dist.Beta(1, 1))
    prob_bill_throws = pyro.sample("prob_bill_throws", dist.Beta(1, 1))
    prob_sally_hits = pyro.sample("prob_sally_hits", dist.Beta(1, 1))
    prob_bill_hits = pyro.sample("prob_bill_hits", dist.Beta(1, 1))
    prob_bottle_shatters_if_sally = pyro.sample("prob_bottle_shatters_if_sally", dist.Beta(1, 1))
    prob_bottle_shatters_if_bill = pyro.sample("prob_bottle_shatters_if_bill", dist.Beta(1, 1))


    sally_throws = pyro.sample("sally_throws", dist.Bernoulli(prob_sally_throws))
    bill_throws = pyro.sample("bill_throws", dist.Bernoulli(prob_bill_throws))

    new_shp = torch.where(sally_throws == 1,prob_sally_hits , 0.0)

    sally_hits = pyro.sample("sally_hits",dist.Bernoulli(new_shp))

    new_bhp = torch.where(
            (
                bill_throws.bool()
                & (~sally_hits.bool())
            )
            == 1,
            prob_bill_hits,
            torch.tensor(0.0),
        )


    bill_hits = pyro.sample("bill_hits", dist.Bernoulli(new_bhp))

    new_bsp = torch.where(
            bill_hits.bool() == 1,
            prob_bottle_shatters_if_bill,
            torch.where(
                sally_hits.bool() == 1,
                prob_bottle_shatters_if_sally,
                torch.tensor(0.0),
            ),
        )

    bottle_shatters = pyro.sample(
            "bottle_shatters", dist.Bernoulli(new_bsp)
        )

    return {
            "sally_throws": sally_throws,
            "bill_throws": bill_throws,
            "sally_hits": sally_hits,
            "bill_hits": bill_hits,
            "bottle_shatters": bottle_shatters,
        }

stones_model.nodes = [
            "sally_throws",
            "bill_throws",
            "sally_hits",
            "bill_hits",
            "bottle_shatters",
        ]

stones_model()

{'sally_throws': tensor(0.),
 'bill_throws': tensor(0.),
 'sally_hits': tensor(0.),
 'bill_hits': tensor(0.),
 'bottle_shatters': tensor(0.)}

In [6]:
pyro.set_rng_seed(101)
stonesHPM = HalpernPearlModifiedApproximate(
    model = stones_model,
    antecedents = ["sally_throws"],
    outcome = "bottle_shatters",
    witness_candidates = ["bill_throws", "bill_hits"],
    observations = {"prob_sally_throws": 1, 
                    "prob_bill_throws": 1,
                    "prob_sally_hits": 1,
                    "prob_bill_hits": 1,
                    "prob_bottle_shatters_if_sally": 1,
                    "prob_bottle_shatters_if_bill": 1,
                    "sally_throws": 1, "bill_throws": 1},
    sample_size = 6,
    event_dim = 0
)

stonesHPM()

print(
any(stonesHPM.consequent_differs.squeeze())
)

print(
stonesHPM.witness_df
)

print(stonesHPM.existential_but_for)

True
   __split_bill_throws  __split_bill_hits  observed  intervened  \
0                    1                  1       1.0         0.0   
1                    1                  1       1.0         0.0   
2                    0                  1       1.0         0.0   
3                    0                  1       1.0         0.0   
4                    0                  0       1.0         1.0   
5                    1                  0       1.0         1.0   

   consequent_differs  
0                True  
1                True  
2                True  
3                True  
4               False  
5               False  
True


## Forest fire

In [7]:
def ff_conjunctive():
    u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped",
                                       u_match_dropped, event_dim=0)
    lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
    forest_fire = pyro.deterministic("forest_fire", torch.logical_and(match_dropped, lightning), event_dim=0).float()

    return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}

ff_conjunctive()

{'match_dropped': tensor(0.),
 'lightning': tensor(0.),
 'forest_fire': tensor(0.)}

In [8]:
def ff_disjunctive():
    u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped",
                                       u_match_dropped, event_dim=0)
    lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
    forest_fire = pyro.deterministic("forest_fire", torch.logical_or(match_dropped, lightning), event_dim=0).float()

    return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}

ff_disjunctive()

{'match_dropped': tensor(0.),
 'lightning': tensor(1.),
 'forest_fire': tensor(1.)}

In [9]:
# In the conjunctive model 
# Each of the two causes is a but-for cause 
pyro.set_rng_seed(101)
ff_conjunctiveHPM = HalpernPearlModifiedApproximate(
    model = ff_conjunctive,
    antecedents = ["match_dropped"],
    outcome = "forest_fire",
    witness_candidates = ["lightning"],
    observations = {"match_dropped": 1, "lightning": 1},
    sample_size = 4,
    event_dim = 0
)

ff_conjunctiveHPM()


print(
ff_conjunctiveHPM.witness_df
)


print(
ff_conjunctiveHPM.existential_but_for
)



   __split_lightning  observed  intervened  consequent_differs
0                  0       1.0         0.0                True
1                  1       1.0         0.0                True
2                  1       1.0         0.0                True
3                  0       1.0         0.0                True
True


In [35]:
# In the disjunctive model 
# either in the observed setting or in the preempted setting of MD
# there still would be fire if there was no lightning
pyro.set_rng_seed(101)
ff_disjunctiveHPM = HalpernPearlModifiedApproximate(
    model = ff_disjunctive,
    antecedents = ["match_dropped"],
    outcome = "forest_fire",
    witness_candidates = ["lightning"],
    observations = {"match_dropped": 1, "lightning": 1},
    sample_size = 4,
    event_dim = 0
)

ff_disjunctiveHPM()

print(
ff_disjunctiveHPM.witness_df
)


any(ff_disjunctiveHPM.consequent_differs.squeeze().tolist())

ff_disjunctiveHPM.existential_but_for




   __split_lightning  observed  intervened  consequent_differs
0                  0       1.0         1.0               False
1                  1       1.0         1.0               False
2                  1       1.0         1.0               False
3                  0       1.0         1.0               False


False

In [31]:
pyro.set_rng_seed(101)
ff_disjunctive_jointHPM = HalpernPearlModifiedApproximate(
    model = ff_disjunctive,
    antecedents = ["match_dropped", "lightning"],
    outcome = "forest_fire",
    witness_candidates = [],
    observations = {"match_dropped": 1, "lightning": 1},
    sample_size = 4,
    event_dim = 0
)

ff_disjunctive_jointHPM()


print(
ff_disjunctive_jointHPM.existential_but_for
)


print(
ff_disjunctive_jointHPM.witness_df
)



tensor(True)
{'intervened': tensor(0.), 'consequent_differs': tensor(True)}


## Intransitivity of actual causality

In [35]:
def bc_function(mt, tt):
    condition1 = (mt == 1) & (tt == 1)
    condition2 = (mt == 1) & (tt == 0)
    condition3 = (mt == 0) & (tt == 1)
    condition4 = ~(condition1 | condition2 | condition3)

    output = torch.where(condition1, torch.tensor(3.0), torch.tensor(0.0))
    output = torch.where(condition2, torch.tensor(0.0), output)
    output = torch.where(condition3, torch.tensor(1.0), output)
    output = torch.where(condition4, torch.tensor(2.0), output)

    return output


def model_doctors():
    u_monday_treatment = pyro.sample("u_monday_treatment", dist.Bernoulli(0.5))

    monday_treatment = pyro.deterministic(
        "monday_treatment", u_monday_treatment, event_dim=0
    )

    tuesday_treatment = pyro.deterministic(
        "tuesday_treatment",
        torch.logical_not(monday_treatment).float(),
        event_dim=0,
    )

    bills_condition = pyro.deterministic(
        "bills_condition",
        bc_function(monday_treatment, tuesday_treatment),
        event_dim=0,
    )

    bill_alive = pyro.deterministic(
        "bill_alive", bills_condition.not_equal(3.0).float(), event_dim=0
    )

    return {
        "monday_treatment": monday_treatment,
        "tuesday_treatment": tuesday_treatment,
        "bills_condition": bills_condition,
        "bill_alive": bill_alive,
    }


# scm_doctors.ranges = {
#     "monday_treatment": torch.Tensor([0.0, 1.0]),
#     "tuesday_treatment": torch.Tensor([0.0, 1.0]),
#     "bills_condition": torch.Tensor([0.0, 1.0, 2.0, 3.0]),
#     "bill_alive": torch.Tensor([0.0, 1.0]),
# }

model_doctors()


{'monday_treatment': tensor(1.),
 'tuesday_treatment': tensor(0.),
 'bills_condition': tensor(0.),
 'bill_alive': tensor(1.)}

In [43]:
doctors1_HPM = HalpernPearlModifiedApproximate(
    model = model_doctors,
    antecedents = ["monday_treatment"],
    outcome = "tuesday_treatment",
    witness_candidates = [],
    observations = {"u_monday_treatment": 1},
    sample_size = 4,
    event_dim = 0
)

doctors1_HPM()


doctors2_HPM = HalpernPearlModifiedApproximate(
    model = model_doctors,
    antecedents = ["tuesday_treatment"],
    outcome = "bill_alive",
    witness_candidates = [],
    observations = {"u_monday_treatment": 1},
    sample_size = 4,
    event_dim = 0
)

doctors2_HPM()

doctors3_HPM = HalpernPearlModifiedApproximate(
    model = model_doctors,
    antecedents = ["monday_treatment"],
    outcome = "bill_alive",
    witness_candidates = [],
    observations = {"u_monday_treatment": 1},
    sample_size = 4,
    event_dim = 0
)

doctors3_HPM()


print(
"step 1:", doctors1_HPM.existential_but_for,
"step 2:",  doctors2_HPM.existential_but_for,
"step 3:", doctors3_HPM.existential_but_for
)

step 1: tensor(True) step 2: tensor(True) step 3: tensor(False)


## Friendly fire incident

This comes from a causal model developed in a real-life incident investgation, as discussed in the Causalis Incident Reporting using SERAS® Reporter and SERAS® Analyst.

a U.S. Special Forces air controller changing the battery on a Global Positioning System device he was using to target a Taliban outpost north of Kandahar.  Three special forces soldiers were killed and 20 were injured when a 2,000-pound, satellite-guided bomb landed, not on the Taliban outpost, but on a battalion command post occupied by American forces and a group of Afghan allies, including Hamid Karzai, now the interim prime minister.The Air Force combat controller was using a Precision Lightweight GPS Receiver to calculate the Taliban's coordinate for the attack. The controller did not realise that after he changed the device's battery, the machine was programmed to automatically come back on displaying coordinates for its own location, the official said.

Minutes before the B-52 strike, the controller had used the GPS receiver to
calculate the latitude and longitude of the Taliban position in minutes and seconds for an airstrike by a Navy F/A-18. Then, with the B-52 approaching the target, the air controller did a second calculation in “degree decimals” required by the bomber crew.  The controller had performed the calculation and recorded the position, when the receiver battery died. Without realizing the machine was programmed to come back on showing the coordinates of its
own location, the controller mistakenly called in the American position to the B-52.

Factors included in the model:

1. The air controller changed the battery on the PLGR
2. Three special forces soldiers were killed and 20 were injured
3. B-52 fired a JDAM bomb at the Allied position
4. The air controller was using the PLGR to calculate the Taliban's coordinates
5. The controller did not realize that the PLGR was programmed to automatically come back on displaying coordinates for its own location
6. The controller had used the PLGR to calculate the latitude and longitude of the Taliban position in minutes and seconds for an airstrike by a Navy F/A-18
7. The air controller did a second calculation in “degree decimals” required by the bomber crew
8. The controller had performed the calculation and recorded the position
9. The controller mistakenly called in the American position to the B-52
10. The B-52 fired a JDAM bomb at the Allied position
11. The U.S. Air Force and Army had a training problem
12. The PLRG resumed displaying the coordinates of its own location after the battery was changed
13. The battery died at the crucial time
14. The controller though he was calling in the Taliban position

The DAG used in the model is as follows:
![Friendly Fire DAG](figures/friendly_fire_dag.png)

In [50]:

def model_friendly_fire():
    u_f4_PLGR_now = pyro.sample("u_f4_PLGR_now", dist.Bernoulli(0.5))
    u_f11_training = pyro.sample("u_f11_training", dist.Bernoulli(0.5))

    f4_PLGR_now = pyro.deterministic("f4_PLGR_now", u_f4_PLGR_now, event_dim=0)
    f11_training = pyro.deterministic(
        "f11_training", u_f11_training, event_dim=0
    )

    f6_PLGR_before = pyro.deterministic(
        "f6_PLGR_before", f4_PLGR_now, event_dim=0
    )
    f7_second_calculation = pyro.deterministic(
        "f7_second_calculation", f4_PLGR_now, event_dim=0
    )
    f13_battery_died = pyro.deterministic(
        "f13_battery_died",
        f6_PLGR_before.bool() & f7_second_calculation.bool(),
        event_dim=0,
    )

    f1_battery_change = pyro.deterministic(
        "f1_battery_change", f13_battery_died, event_dim=0
    )

    f12_PLGR_after = pyro.deterministic(
        "f12_PLGR_after", f1_battery_change, event_dim=0
    )

    f5_unaware = pyro.deterministic("f5_unaware", f11_training, event_dim=0)

    f14_wrong_position = pyro.deterministic(
        "f14_wrong_position", f5_unaware, event_dim=0
    )

    f9_mistake_call = pyro.deterministic(
        "f9_mistake_call",
            f12_PLGR_after.bool() & 
            f14_wrong_position.bool(),
        event_dim=0,
    )

    f3_fired = pyro.deterministic("f3_fired", f9_mistake_call, event_dim=0)

    f10_landed = pyro.deterministic(
        "f10_landed", f3_fired.bool() &  f9_mistake_call.bool(), event_dim=0
    )

    f2_killed = pyro.deterministic("f2_killed", f10_landed, event_dim=0)

    return {
        "f1_battery_change": f1_battery_change,
        "f2_killed": f2_killed,
        "f3_fired": f3_fired,
        "f4_PLGR_now": f4_PLGR_now,
        "f5_unaware": f5_unaware,
        "f6_PLGR_before": f6_PLGR_before,
        "f7_second_calculation": f7_second_calculation,
        "f9_mistake_call": f9_mistake_call,
        "f10_landed": f10_landed,
        "f11_training": f11_training,
        "f12_PLGR_after": f12_PLGR_after,
        "f13_battery_died": f13_battery_died,
        "f14_wrong_position": f14_wrong_position,
    }

model_friendly_fire()


{'f1_battery_change': tensor(True),
 'f2_killed': tensor(False),
 'f3_fired': tensor(False),
 'f4_PLGR_now': tensor(1.),
 'f5_unaware': tensor(0.),
 'f6_PLGR_before': tensor(1.),
 'f7_second_calculation': tensor(1.),
 'f9_mistake_call': tensor(False),
 'f10_landed': tensor(False),
 'f11_training': tensor(0.),
 'f12_PLGR_after': tensor(True),
 'f13_battery_died': tensor(True),
 'f14_wrong_position': tensor(0.)}

In [57]:
friendly_fire_HPM = HalpernPearlModifiedApproximate(
    model = model_friendly_fire,
    antecedents = ["f6_PLGR_before", "f7_second_calculation"],
    outcome = "f2_killed",
    witness_candidates = ["f4_PLGR_now","f5_unaware",
    "f11_training",
    "f14_wrong_position"],
    observations = {"u_f4_PLGR_now": 1.0, "u_f11_training": 1.0},
    sample_size = 20,
    event_dim = 0
)

friendly_fire_sub1_HPM = HalpernPearlModifiedApproximate(
    model = model_friendly_fire,
    antecedents = ["f6_PLGR_before"],
    outcome = "f2_killed",
    witness_candidates = ["f4_PLGR_now","f5_unaware",
    "f11_training",
    "f14_wrong_position", "f7_second_calculation"],
    observations = {"u_f4_PLGR_now": 1.0, "u_f11_training": 1.0},
    sample_size = 20,
    event_dim = 0
)

friendly_fire_sub2_HPM = HalpernPearlModifiedApproximate(
    model = model_friendly_fire,
    antecedents = ["f7_second_calculation"],
    outcome = "f2_killed",
    witness_candidates = ["f4_PLGR_now","f5_unaware",
    "f11_training",
    "f14_wrong_position", "f6_PLGR_before"],
    observations = {"u_f4_PLGR_now": 1.0, "u_f11_training": 1.0},
    sample_size = 20,
    event_dim = 0
)

friendly_fire_HPM()
friendly_fire_sub1_HPM()
friendly_fire_sub2_HPM()

print(
"tuple: ", friendly_fire_HPM.existential_but_for,
"PLGR_before: ", friendly_fire_sub1_HPM.existential_but_for,
"second calculation: ", friendly_fire_sub2_HPM.existential_but_for
)


tuple:  True PLGR_before:  True second calculation:  True


## Voting


In [52]:
def voting_model():
    u_vote0 = pyro.sample("u_vote0", dist.Bernoulli(0.6))
    u_vote1 = pyro.sample("u_vote1", dist.Bernoulli(0.6))
    u_vote2 = pyro.sample("u_vote2", dist.Bernoulli(0.6))
    u_vote3 = pyro.sample("u_vote3", dist.Bernoulli(0.6))
    u_vote4 = pyro.sample("u_vote4", dist.Bernoulli(0.6))
    u_vote5 = pyro.sample("u_vote5", dist.Bernoulli(0.6))
    
    vote0 = pyro.deterministic("vote0", u_vote0, event_dim=0)
    vote1 = pyro.deterministic("vote1", u_vote1, event_dim=0)
    vote2 = pyro.deterministic("vote2", u_vote2, event_dim=0)
    vote3 = pyro.deterministic("vote3", u_vote3, event_dim=0)
    vote4 = pyro.deterministic("vote4", u_vote4, event_dim=0)
    vote5 = pyro.deterministic("vote5", u_vote5, event_dim=0)

    return {"outcome": vote0 + vote1 + vote2 + vote3 + vote4 +
            vote5  > 3}

voting_model()

{'outcome': tensor(True)}

In [63]:
pyro.set_rng_seed(32)
votingHPM = HalpernPearlModifiedApproximate(
    model = voting_model,
    antecedents = ["vote0"],
    outcome = "outcome",
    witness_candidates = [f"vote{i}" for i in range(1,6)],
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=0., u_vote4=0., uvote_5=0,
                        ),
    sample_size = 10)

votingHPM()

print(
votingHPM.existential_but_for
)

print(
votingHPM.witness_df
)

print(votingHPM.trace.nodes.keys())

print(votingHPM.trace.nodes['__split_vote1']['value'])

True
   __split_vote1  __split_vote2  __split_vote3  __split_vote4  __split_vote5  \
0              0              1              0              1              1   
1              0              0              0              1              1   
2              0              0              1              0              0   
3              1              1              1              0              1   
4              0              1              1              0              0   
5              1              0              0              0              0   
6              0              1              1              0              0   
7              1              1              1              0              1   
8              0              0              1              1              0   
9              0              1              1              1              1   

   observed  intervened  consequent_differs  
0     False       False               False  
1      True       Fals

In [50]:
def voting_model():
    u_vote0 = pyro.sample("u_vote0", dist.Bernoulli(0.6))
    u_vote1 = pyro.sample("u_vote1", dist.Bernoulli(0.6))
    u_vote2 = pyro.sample("u_vote2", dist.Bernoulli(0.6))
    u_vote3 = pyro.sample("u_vote3", dist.Bernoulli(0.6))
    u_vote4 = pyro.sample("u_vote4", dist.Bernoulli(0.6))
    u_vote5 = pyro.sample("u_vote5", dist.Bernoulli(0.6))
    u_vote6 = pyro.sample("u_vote6", dist.Bernoulli(0.6))
    u_vote7 = pyro.sample("u_vote7", dist.Bernoulli(0.6))
    u_vote8 = pyro.sample("u_vote8", dist.Bernoulli(0.6))
    u_vote9 = pyro.sample("u_vote9", dist.Bernoulli(0.6))
    u_vote10 = pyro.sample("u_vote10", dist.Bernoulli(0.6))
    
    vote0 = pyro.deterministic("vote0", u_vote0, event_dim=0)
    vote1 = pyro.deterministic("vote1", u_vote1, event_dim=0)
    vote2 = pyro.deterministic("vote2", u_vote2, event_dim=0)
    vote3 = pyro.deterministic("vote3", u_vote3, event_dim=0)
    vote4 = pyro.deterministic("vote4", u_vote4, event_dim=0)
    vote5 = pyro.deterministic("vote5", u_vote5, event_dim=0)
    vote6 = pyro.deterministic("vote6", u_vote6, event_dim=0)
    vote7 = pyro.deterministic("vote7", u_vote7, event_dim=0)
    vote8 = pyro.deterministic("vote8", u_vote8, event_dim=0)
    vote9 = pyro.deterministic("vote9", u_vote9, event_dim=0)
    vote10 = pyro.deterministic("vote10", u_vote10, event_dim=0)

    return {"outcome": vote0 + vote1 + vote2 + vote3 + vote4 +
            vote5 + vote6 + vote7 + vote8 + vote9 + vote10 > 5}

voting_model()

{'outcome': tensor(True)}

In [51]:
# if you're one of six who voted for, you are an actual cause
voting6HPM = HalpernPearlModifiedApproximate(
    model = voting_model,
    antecedents = ["vote0"],
    outcome = "outcome",
    witness_candidates = [f"vote{i}" for i in range(1,11)],
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=1., uvote_5=1,
                         u_vote6=0., u_vote7=0., u_vote8=0.,
                        u_vote9=0., u_vote10=0.),
    sample_size = 1000)

voting6HPM()

print(
voting6HPM.existential_but_for
)


True


In [34]:
# if you're one of seven who voted for, you are an actual cause
voting7HPM = HalpernPearlModifiedApproximate(
    model = voting_model,
    antecedents = ["vote0"],
    outcome = "outcome",
    witness_candidates = [f"vote{i}" for i in range(1,11)],
    observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=1., uvote_5=0,
                         u_vote6=0., u_vote7=0., u_vote8=0.,
                        u_vote9=0., u_vote10=0.),
    sample_size = 10)

voting7HPM()

print(
voting7HPM.existential_but_for
)

print(
    voting7HPM.witness_df
)


True
   __split_vote1  __split_vote2  __split_vote3  __split_vote4  __split_vote5  \
0              1              0              0              0              1   
1              0              1              0              1              0   
2              1              1              0              1              1   
3              1              1              1              1              1   
4              1              0              0              0              1   
5              0              1              0              0              1   
6              1              0              1              0              0   
7              1              1              0              1              1   
8              0              0              0              0              1   
9              1              1              0              0              1   

   __split_vote6  __split_vote7  __split_vote8  __split_vote9  __split_vote10  \
0              1              1  

In [60]:
def voting_model_7_4():
    u_vote0 = pyro.sample("u_vote0", dist.Bernoulli(0.6))
    u_vote1 = pyro.sample("u_vote1", dist.Bernoulli(0.6))
    u_vote2 = pyro.sample("u_vote2", dist.Bernoulli(0.6))
    u_vote3 = pyro.sample("u_vote3", dist.Bernoulli(0.6))
    u_vote4 = pyro.sample("u_vote4", dist.Bernoulli(0.6))

    vote0 = pyro.deterministic("vote0", u_vote0, event_dim=0)
    vote1 = pyro.deterministic("vote1", u_vote1, event_dim=0)
    vote2 = pyro.deterministic("vote2", u_vote2, event_dim=0)
    vote3 = pyro.deterministic("vote3", u_vote3, event_dim=0)
    vote4 = pyro.deterministic("vote4", u_vote4, event_dim=0)
    return {"outcome": vote0 + vote1 + vote2 + vote3 + vote4 >= 3}

True


Unnamed: 0,__split_vote1,__split_vote2,__split_vote3,__split_vote4,observed,intervened,consequent_differs
0,1,0,0,0,True,False,True
1,1,1,1,0,True,False,True
2,1,1,1,1,True,False,True
3,0,1,0,1,True,False,True
4,0,0,0,0,True,False,True
...,...,...,...,...,...,...,...
995,1,0,1,0,True,False,True
996,0,1,0,0,True,False,True
997,0,0,0,0,True,False,True
998,1,1,1,0,True,False,True


In [None]:
#test