In [1]:
import functools
import contextlib
import collections
from typing import Callable, Iterable, TypeVar, Mapping, List, Dict

import pyro
import torch  # noqa: F401

from chirho.counterfactual.handlers.selection import get_factual_indices
from chirho.indexed.ops import IndexSet, cond, gather, indices_of, scatter

# Generic type variables for annotations: ``T`` appears in ``ExplainCauses``'s
# signature; ``S`` is not referenced by any visible cell.
S = TypeVar("S")
T = TypeVar("T")

import pyro
import chirho
import pyro.distributions as dist
import pyro.infer
import pytest
import torch
import pandas as pd

from chirho.counterfactual.handlers.counterfactual import (MultiWorldCounterfactual,
        Preemptions)
from chirho.counterfactual.handlers.explanation import (
    SearchForCause,
    consequent_differs,
    random_intervention,
    undo_split,
    uniform_proposal,
)
from chirho.counterfactual.ops import preempt, split
from chirho.indexed.ops import IndexSet, gather, indices_of
from chirho.observational.handlers.condition import Factors, condition
from chirho.interventional.ops import Intervention, intervene
from chirho.interventional.handlers import do

In [2]:
@contextlib.contextmanager
def ExplainCauses(
    antecedents: Mapping[str, Intervention[T]]
    | Mapping[str, pyro.distributions.constraints.Constraint],
    witnesses: Mapping[str, Intervention[T]] | Iterable[str],
    consequents: Mapping[str, Callable[[T], float | torch.Tensor]]
    | Iterable[str],
    *,
    antecedent_bias: float = 0.0,
    witness_bias: float = 0.0,
    consequent_eps: float = -1e8,
    antecedent_prefix: str = "__antecedent_",
    witness_prefix: str = "__witness_",
    consequent_prefix: str = "__consequent_",
):
    """
    Effect handler for causal explanation.

    While the ``with`` block runs, three handlers are active: ``SearchForCause``
    over the antecedents, ``Preemptions`` over the witnesses and ``Factors``
    over the consequents.

    :param antecedents: A mapping from antecedent names to interventions, or
        from antecedent names to constraints. In the latter case each entry is
        replaced by a ``random_intervention`` proposal. Only the first value
        is inspected to decide which form was given.
    :param witnesses: A mapping from witness names to interventions, or an
        iterable of witness names. Each name defaults to ``undo_split`` over
        the antecedents. A single string is treated as one name.
    :param consequents: A mapping from consequent names to factor functions,
        or an iterable of consequent names. Each name defaults to
        ``consequent_differs`` over the antecedents. A single string is
        treated as one name.
    :param antecedent_bias: ``bias`` passed to ``SearchForCause``.
    :param witness_bias: ``bias`` passed to ``Preemptions``.
    :param consequent_eps: ``eps`` passed to the default
        ``consequent_differs`` factors.
    :param antecedent_prefix: site-name prefix for the antecedent handler.
    :param witness_prefix: site-name prefix for the witness handler.
    :param consequent_prefix: site-name prefix for the consequent factors.
    :raises ValueError: on entering the block, if there are no antecedents or
        no consequents, or if the consequents overlap the antecedents or the
        witnesses.
    """
    # A bare string is an Iterable[str] of its characters. Treat it as one
    # site name instead of silently splitting it into single letters.
    if isinstance(witnesses, str):
        witnesses = [witnesses]
    if isinstance(consequents, str):
        consequents = [consequents]

    # Must precede the peek at the first value below, which would otherwise
    # fail on an empty mapping before this message could be raised.
    if len(antecedents) == 0:
        raise ValueError("must have at least one antecedent")

    if isinstance(
        next(iter(antecedents.values())),
        pyro.distributions.constraints.Constraint,
    ):
        antecedents = {
            a: random_intervention(s, name=f"{antecedent_prefix}_proposal_{a}")
            for a, s in antecedents.items()
        }

    if not isinstance(witnesses, collections.abc.Mapping):
        witnesses = {
            w: undo_split(antecedents=list(antecedents.keys()))
            for w in witnesses
        }

    if not isinstance(consequents, collections.abc.Mapping):
        consequents = {
            c: consequent_differs(
                antecedents=list(antecedents.keys()), eps=consequent_eps
            )
            for c in consequents
        }

    if len(consequents) == 0:
        raise ValueError("must have at least one consequent")

    if set(consequents.keys()) & set(antecedents.keys()):
        raise ValueError(
            "consequents and possible antecedents must be disjoint"
        )

    if set(consequents.keys()) & set(witnesses.keys()):
        raise ValueError("consequents and possible witnesses must be disjoint")

    antecedent_handler = SearchForCause(
        actions=antecedents, bias=antecedent_bias, prefix=antecedent_prefix
    )
    witness_handler = Preemptions(
        actions=witnesses, bias=witness_bias, prefix=witness_prefix
    )
    consequent_handler = Factors(factors=consequents, prefix=consequent_prefix)

    with antecedent_handler, witness_handler, consequent_handler:
        yield

In [3]:
def _gather_world(value, antecedents, witnesses, world):
    """Select index ``world`` of ``value`` along each named split dimension.

    Only names (antecedent keys and witnesses) that appear in
    ``indices_of(value, event_dim=0)`` are gathered; other names are ignored.

    :param value: a tensor recorded at a trace site.
    :param antecedents: mapping (or iterable) of antecedent names.
    :param witnesses: any iterable of witness names (list, tuple, dict, ...).
    :param world: index selected along every matching dimension.
    """
    # Computed once instead of once per candidate name.
    present = indices_of(value, event_dim=0)
    names = [name for name in (*antecedents, *witnesses) if name in present]
    return gather(value, IndexSet(**{name: {world} for name in names}), event_dim=0)


def gather_observed(value, antecedents, witnesses):
    """Return ``value`` at index 0 (observed world) along each split dimension."""
    return _gather_world(value, antecedents, witnesses, 0)


def gather_intervened(value, antecedents, witnesses):
    """Return ``value`` at index 1 (intervened world) along each split dimension."""
    return _gather_world(value, antecedents, witnesses, 1)


def get_table(
    trace,
    mwc,
    antecedents,
    witnesses,
    consequents,
    *,
    antecedent_prefix: str = "__antecedent_",
    witness_prefix: str = "__witness_",
    consequent_prefix: str = "__consequent_",
):
    """Tabulate a traced ``ExplainCauses`` run, one row per distinct sample.

    Columns, in order of creation:

    * for each antecedent ``a``:
      - ``a_obs`` / ``a_int``: value in the observed and intervened worlds;
      - ``apr_a``: value of the ``{antecedent_prefix}a`` site;
      - ``apr_a_lp``: log-probability of ``apr_a``.
    * for each witness ``w``:
      - ``w_obs`` / ``w_int``;
      - ``wpr_w``: value of the ``{witness_prefix}w`` site.
    * for each consequent ``c``:
      - ``c_obs`` / ``c_int``;
      - ``c_lp``: log-probability of the ``{consequent_prefix}c`` factor,
        gathered in the intervened world.
    * ``sum_log_prob``: sum of exactly the ``*_lp`` columns listed above.

    Duplicate rows are dropped. The result is sorted by ``sum_log_prob`` in
    descending order.

    :param trace: the ``pyro.poutine.trace`` handler that recorded the run.
    :param mwc: the ``MultiWorldCounterfactual`` handler of the run. It is
        re-entered while gathering, presumably so that ``indices_of`` /
        ``gather`` resolve its world dimensions.
    :param antecedents: mapping whose keys are antecedent site names.
    :param witnesses: iterable of witness site names.
    :param consequents: iterable of consequent site names.
    :param antecedent_prefix: antecedent site prefix; must match the one
        given to ``ExplainCauses`` (defaults agree).
    :param witness_prefix: likewise for witnesses.
    :param consequent_prefix: likewise for consequent factors.
    :returns: a ``pandas.DataFrame``.
    """
    nodes = trace.trace.nodes
    values_table = {}
    # Columns summed into ``sum_log_prob``. They are tracked explicitly,
    # because matching on a name suffix would also catch value columns such
    # as ``apr_help`` for a site called ``help``.
    lp_columns = []

    def add_world_columns(site):
        # Record the site's value in the observed and the intervened world.
        value = nodes[site]["value"]
        values_table[f"{site}_obs"] = (
            gather_observed(value, antecedents, witnesses).squeeze().tolist()
        )
        values_table[f"{site}_int"] = (
            gather_intervened(value, antecedents, witnesses).squeeze().tolist()
        )

    with mwc:
        for antecedent in antecedents:
            add_world_columns(antecedent)
            preemption = nodes[f"{antecedent_prefix}{antecedent}"]
            values_table[f"apr_{antecedent}"] = preemption["value"].squeeze().tolist()
            lp_column = f"apr_{antecedent}_lp"
            values_table[lp_column] = (
                preemption["fn"].log_prob(preemption["value"]).squeeze().tolist()
            )
            lp_columns.append(lp_column)

        for witness in witnesses:
            add_world_columns(witness)
            values_table[f"wpr_{witness}"] = (
                nodes[f"{witness_prefix}{witness}"]["value"].squeeze().tolist()
            )

        for consequent in consequents:
            add_world_columns(consequent)
            factor_site = nodes[f"{consequent_prefix}{consequent}"]
            # Score the factor at the value recorded for its own site rather
            # than at an arbitrary placeholder tensor.
            factor_lp = factor_site["fn"].log_prob(factor_site["value"])
            lp_column = f"{consequent}_lp"
            values_table[lp_column] = (
                gather_intervened(factor_lp, antecedents, witnesses).squeeze().tolist()
            )
            lp_columns.append(lp_column)

    return (
        pd.DataFrame(values_table)
        .drop_duplicates()
        .assign(sum_log_prob=lambda df: df[lp_columns].sum(axis=1))
        .sort_values(by="sum_log_prob", ascending=False)
    )

In [4]:
# This reduces the actual-causality check to a property of the summed log
# probabilities of the antecedent-preemption and consequent-differs sites.

def ac_check(trace, mwc, antecedents, witnesses, consequents, *, eps: float = -1e8):
    """Report whether ``antecedents`` form an actual cause in a traced run.

    Builds the table with ``get_table`` and inspects its best-scoring rows:

    * If the table is empty, or its best ``sum_log_prob`` is at or below
      ``eps``, no sampled intervention changed the consequent.
    * Otherwise the set is an actual cause when at least one best-scoring row
      has ``apr_<name> == 0`` for every antecedent.
    * It is reported as not minimal when every best-scoring row leaves some
      antecedent with a non-zero ``apr_<name>``.

    A message describing the verdict is printed.

    :param eps: score at or below which no consequent difference is assumed.
        It should match the ``consequent_eps`` given to ``ExplainCauses``,
        whose default is also ``-1e8``.
    :returns: ``True`` if the antecedent set is an actual cause, else ``False``.
    """
    table = get_table(trace, mwc, antecedents, witnesses, consequents)
    scores = table["sum_log_prob"]

    if table.empty or scores.max() <= eps:
        print("No resulting difference to the consequent in the sample.")
        return False

    winners = table[scores == scores.max()]
    # A winning row supports the whole antecedent set only if every
    # antecedent's preemption value in that row is 0.
    is_actual_cause = any(
        all(row[f"apr_{antecedent}"] == 0 for antecedent in antecedents)
        for _, row in winners.iterrows()
    )

    if is_actual_cause:
        print("The antecedent set is an actual cause.")
    else:
        print("The antecedent set is not minimal.")
    return is_actual_cause

### Stone-throwing

In [5]:
@pyro.infer.config_enumerate
def stones_model():
    """Stone-throwing model: Sally and Bill may each throw a stone at a bottle.

    Every probability has a ``Beta(1, 1)`` prior. Sally can hit only if she
    throws. Bill can hit only if he throws and Sally did not hit. The bottle
    can shatter only if somebody hit it.
    """
    prior_names = (
        "prob_sally_throws",
        "prob_bill_throws",
        "prob_sally_hits",
        "prob_bill_hits",
        "prob_bottle_shatters_if_sally",
        "prob_bottle_shatters_if_bill",
    )
    # Sampled in this fixed order, so the trace's site order is unchanged.
    p = {name: pyro.sample(name, dist.Beta(1, 1)) for name in prior_names}

    sally_throws = pyro.sample("sally_throws", dist.Bernoulli(p["prob_sally_throws"]))
    bill_throws = pyro.sample("bill_throws", dist.Bernoulli(p["prob_bill_throws"]))

    sally_hits = pyro.sample(
        "sally_hits",
        dist.Bernoulli(torch.where(sally_throws == 1, p["prob_sally_hits"], 0.0)),
    )

    bill_can_hit = bill_throws.bool() & (~sally_hits.bool())
    bill_hits = pyro.sample(
        "bill_hits",
        dist.Bernoulli(
            torch.where(bill_can_hit, p["prob_bill_hits"], torch.tensor(0.0))
        ),
    )

    shatter_prob = torch.where(
        bill_hits.bool(),
        p["prob_bottle_shatters_if_bill"],
        torch.where(
            sally_hits.bool(),
            p["prob_bottle_shatters_if_sally"],
            torch.tensor(0.0),
        ),
    )
    bottle_shatters = pyro.sample("bottle_shatters", dist.Bernoulli(shatter_prob))

    return dict(
        sally_throws=sally_throws,
        bill_throws=bill_throws,
        sally_hits=sally_hits,
        bill_hits=bill_hits,
        bottle_shatters=bottle_shatters,
    )


stones_model.nodes = [
    "sally_throws",
    "bill_throws",
    "sally_hits",
    "bill_hits",
    "bottle_shatters",
]

In [6]:
def tensorize_observations(observations):
    """Return a new dict with every value of ``observations`` passed through ``torch.as_tensor``."""
    tensorized = {}
    for name, value in observations.items():
        tensorized[name] = torch.as_tensor(value)
    return tensorized


# Observed world: every probability is pinned to 1 and both children throw.
observations = dict(
    prob_sally_throws=1.0,
    prob_bill_throws=1.0,
    prob_sally_hits=1.0,
    prob_bill_hits=1.0,
    prob_bottle_shatters_if_sally=1.0,
    prob_bottle_shatters_if_bill=1.0,
    sally_throws=1.0,
    bill_throws=1.0,
)

observations_tensorized = tensorize_observations(observations)

# Query: does Sally's throw (counterfactually set to 0.0) cause the shattering?
antecedents = dict(sally_throws=0.0)
antencedent_bias = 0.1  # (sic) later cells refer to this spelling
witnesses = ["bill_throws", "bill_hits"]
consequents = ["bottle_shatters"]

#TODO? fails silently when consequents is a string

In [7]:
# Trace 200 plated samples of the stone-throwing model under the explanation handlers.
with (
    MultiWorldCounterfactual() as mwc,
    ExplainCauses(
        antecedents=antecedents,
        witnesses=witnesses,
        consequents=consequents,
        antecedent_bias=antencedent_bias,
    ),
    condition(data=observations_tensorized),
    pyro.plate("sample", 200),
    pyro.poutine.trace() as tr,
):
    stones_model()



In [8]:
stones_table = get_table(
    trace=tr,
    mwc=mwc,
    antecedents=antecedents,
    witnesses=witnesses,
    consequents=consequents,
)
display(stones_table)

Unnamed: 0,sally_throws_obs,sally_throws_int,apr_sally_throws,apr_sally_throws_lp,bill_throws_obs,bill_throws_int,wpr_bill_throws,bill_hits_obs,bill_hits_int,wpr_bill_hits,bottle_shatters_obs,bottle_shatters_int,bottle_shatters_lp,sum_log_prob
3,1.0,0.0,0,-0.916291,1.0,1.0,1,0.0,0.0,1,1.0,0.0,0.0,-0.9162907
4,1.0,0.0,0,-0.916291,1.0,1.0,0,0.0,0.0,1,1.0,0.0,0.0,-0.9162907
0,1.0,1.0,1,-0.510826,1.0,1.0,1,0.0,0.0,1,1.0,1.0,-100000000.0,-100000000.0
2,1.0,1.0,1,-0.510826,1.0,1.0,1,0.0,0.0,0,1.0,1.0,-100000000.0,-100000000.0
8,1.0,1.0,1,-0.510826,1.0,1.0,0,0.0,0.0,0,1.0,1.0,-100000000.0,-100000000.0
13,1.0,1.0,1,-0.510826,1.0,1.0,0,0.0,0.0,1,1.0,1.0,-100000000.0,-100000000.0
5,1.0,0.0,0,-0.916291,1.0,1.0,0,0.0,1.0,0,1.0,1.0,-100000000.0,-100000000.0
57,1.0,0.0,0,-0.916291,1.0,1.0,1,0.0,1.0,0,1.0,1.0,-100000000.0,-100000000.0


In [9]:
ac_check(trace=tr, mwc=mwc, antecedents=antecedents, witnesses=witnesses, consequents=consequents)

The antecedent set is an actual cause.


True

In [10]:
# This antecedent set is not minimal.
antecedents2 = dict(sally_throws=0.0, bill_throws=0.0)
# 'bill_throws' is now an antecedent, so it is no longer listed as a witness
# (we know it is not downstream from any split).
witnesses2 = ["bill_hits"]

with (
    MultiWorldCounterfactual() as mwc2,
    ExplainCauses(
        antecedents=antecedents2,
        witnesses=witnesses2,
        consequents=consequents,
        antecedent_bias=antencedent_bias,
    ),
    condition(data=observations_tensorized),
    pyro.plate("sample", 200),
    pyro.poutine.trace() as tr2,
):
    stones_model()

In [11]:
stones_table2 = get_table(
    trace=tr2,
    mwc=mwc2,
    antecedents=antecedents2,
    witnesses=witnesses2,
    consequents=consequents,
)
display(stones_table2)

Unnamed: 0,sally_throws_obs,sally_throws_int,apr_sally_throws,apr_sally_throws_lp,bill_throws_obs,bill_throws_int,apr_bill_throws,apr_bill_throws_lp,bill_hits_obs,bill_hits_int,wpr_bill_hits,bottle_shatters_obs,bottle_shatters_int,bottle_shatters_lp,sum_log_prob
6,1.0,0.0,0,-0.916291,1.0,1.0,1,-0.510826,0.0,0.0,1,1.0,0.0,0.0,-1.427116
3,1.0,0.0,0,-0.916291,1.0,0.0,0,-0.916291,0.0,0.0,0,1.0,0.0,0.0,-1.832581
24,1.0,0.0,0,-0.916291,1.0,0.0,0,-0.916291,0.0,0.0,1,1.0,0.0,0.0,-1.832581
0,1.0,1.0,1,-0.510826,1.0,1.0,1,-0.510826,0.0,0.0,0,1.0,1.0,-100000000.0,-100000000.0
7,1.0,1.0,1,-0.510826,1.0,1.0,1,-0.510826,0.0,0.0,1,1.0,1.0,-100000000.0,-100000000.0
1,1.0,0.0,0,-0.916291,1.0,1.0,1,-0.510826,0.0,1.0,0,1.0,1.0,-100000000.0,-100000000.0
4,1.0,1.0,1,-0.510826,1.0,0.0,0,-0.916291,0.0,0.0,1,1.0,1.0,-100000000.0,-100000000.0
5,1.0,1.0,1,-0.510826,1.0,0.0,0,-0.916291,0.0,0.0,0,1.0,1.0,-100000000.0,-100000000.0


In [12]:
ac_check(trace=tr2, mwc=mwc2, antecedents=antecedents2, witnesses=witnesses2, consequents=consequents)

The antecedent set is not minimal.


False

### Forest fire

In [13]:
def ff_conjunctive():
    """Conjunctive forest-fire model: fire needs both a dropped match and lightning."""
    # Exogenous noise terms, sampled first.
    u_match = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_light = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    # Endogenous variables copy their noise terms.
    match_dropped = pyro.deterministic("match_dropped", u_match, event_dim=0)
    lightning = pyro.deterministic("lightning", u_light, event_dim=0)

    fire = match_dropped.bool() & lightning.bool()
    # The trace records the boolean; callers receive a float tensor.
    forest_fire = pyro.deterministic("forest_fire", fire, event_dim=0).float()

    return dict(match_dropped=match_dropped, lightning=lightning, forest_fire=forest_fire)
    


In [14]:
def ff_disjunctive():
    """Disjunctive forest-fire model: a dropped match or lightning alone starts the fire."""
    # Exogenous noise terms, sampled first.
    u_match = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_light = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    # Endogenous variables copy their noise terms.
    match_dropped = pyro.deterministic("match_dropped", u_match, event_dim=0)
    lightning = pyro.deterministic("lightning", u_light, event_dim=0)

    fire = match_dropped.bool() | lightning.bool()
    # The trace records the boolean; callers receive a float tensor.
    forest_fire = pyro.deterministic("forest_fire", fire, event_dim=0).float()

    return dict(match_dropped=match_dropped, lightning=lightning, forest_fire=forest_fire)


In [15]:
antecedents_ff = dict(match_dropped=0.0)
witnesses_ff = ["lightning"]
consequents_ff = ["forest_fire"]
# Observed world: the match was dropped and lightning struck.
observations_ff = tensorize_observations(dict(match_dropped=1.0, lightning=1.0))

In [16]:
with (
    MultiWorldCounterfactual() as mwc_ff,
    ExplainCauses(
        antecedents=antecedents_ff,
        witnesses=witnesses_ff,
        consequents=consequents_ff,
        antecedent_bias=antencedent_bias,
    ),
    condition(data=observations_ff),
    pyro.plate("sample", 200),
    pyro.poutine.trace() as tr_ff,
):
    ff_conjunctive()

In [17]:
# In the conjunctive model each of the two factors is a but-for cause.

ac_check(
    trace=tr_ff,
    mwc=mwc_ff,
    antecedents=antecedents_ff,
    witnesses=witnesses_ff,
    consequents=consequents_ff,
)

The antecedent set is an actual cause.


True

In [18]:
# In the disjunctive model there would still be fire without the dropped
# match, because lightning alone suffices.

with (
    MultiWorldCounterfactual() as mwc_ffd,
    ExplainCauses(
        antecedents=antecedents_ff,
        witnesses=witnesses_ff,
        consequents=consequents_ff,
        antecedent_bias=antencedent_bias,
    ),
    condition(data=observations_ff),
    pyro.plate("sample", 200),
    pyro.poutine.trace() as tr_ffd,
):
    ff_disjunctive()

In [19]:
ac_check(trace=tr_ffd, mwc=mwc_ffd, antecedents=antecedents_ff, witnesses=witnesses_ff, consequents=consequents_ff)

No resulting difference to the consequent in the sample.


In [20]:
# In the disjunctive model the actual cause is the combination of the two
# factors.

antecedents_ffd2 = dict(match_dropped=0.0, lightning=0.0)
witnesses_ffd2 = []

with (
    MultiWorldCounterfactual() as mwc_ffd2,
    ExplainCauses(
        antecedents=antecedents_ffd2,
        witnesses=witnesses_ffd2,
        consequents=consequents_ff,
        antecedent_bias=antencedent_bias,
    ),
    condition(data=observations_ff),
    pyro.plate("sample", 200),
    pyro.poutine.trace() as tr_ffd2,
):
    ff_disjunctive()

In [21]:
ac_check(trace=tr_ffd2, mwc=mwc_ffd2, antecedents=antecedents_ffd2, witnesses=witnesses_ffd2, consequents=consequents_ff)

The antecedent set is an actual cause.


True

### Doctors

In [22]:
def bc_function(mt, tt):
    """Map Monday/Tuesday treatment indicators to Bill's condition code.

    Codes:
    * 3.0 if treated on both days;
    * 0.0 if treated on Monday only;
    * 1.0 if treated on Tuesday only;
    * 2.0 for every other combination (including values other than 0/1).
    """
    both_days = (mt == 1) & (tt == 1)
    monday_only = (mt == 1) & (tt == 0)
    tuesday_only = (mt == 0) & (tt == 1)

    return torch.where(
        both_days,
        torch.tensor(3.0),
        torch.where(
            monday_only,
            torch.tensor(0.0),
            torch.where(tuesday_only, torch.tensor(1.0), torch.tensor(2.0)),
        ),
    )


def model_doctors():
    """Doctors model.

    Monday's treatment is random. Tuesday's treatment is given exactly when
    Monday's was not. ``bill_alive`` is 0 exactly when ``bills_condition`` is
    3, i.e. when Bill is treated on both days.
    """

    def record(name, value):
        # Every endogenous site here is a scalar-event deterministic node.
        return pyro.deterministic(name, value, event_dim=0)

    monday_treatment = record(
        "monday_treatment",
        pyro.sample("u_monday_treatment", dist.Bernoulli(0.5)),
    )
    tuesday_treatment = record(
        "tuesday_treatment", torch.logical_not(monday_treatment).float()
    )
    bills_condition = record(
        "bills_condition", bc_function(monday_treatment, tuesday_treatment)
    )
    bill_alive = record("bill_alive", bills_condition.not_equal(3.0).float())

    return dict(
        monday_treatment=monday_treatment,
        tuesday_treatment=tuesday_treatment,
        bills_condition=bills_condition,
        bill_alive=bill_alive,
    )

In [23]:
antecedents_doc1 = dict(monday_treatment=0.0)
witnesses_doc = []
consequents_doc1 = ["tuesday_treatment"]
# Observed world: Bill was treated on Monday.
observations_doc = tensorize_observations(dict(u_monday_treatment=1.0))

In [24]:
# The first actual causal link holds: Monday's treatment -> Tuesday's treatment.

with (
    MultiWorldCounterfactual() as mwc_doc1,
    ExplainCauses(
        antecedents=antecedents_doc1,
        witnesses=witnesses_doc,
        consequents=consequents_doc1,
        antecedent_bias=antencedent_bias,
    ),
    condition(data=observations_doc),
    pyro.plate("sample", 200),
    pyro.poutine.trace() as tr_doc1,
):
    model_doctors()

ac_check(
    trace=tr_doc1,
    mwc=mwc_doc1,
    antecedents=antecedents_doc1,
    witnesses=witnesses_doc,
    consequents=consequents_doc1,
)

The antecedent set is an actual cause.


True

In [25]:
# So does the second: Tuesday's treatment -> Bill being alive.

antecedents_doc2 = dict(tuesday_treatment=1.0)
consequents_doc2 = ["bill_alive"]

with (
    MultiWorldCounterfactual() as mwc_doc2,
    ExplainCauses(
        antecedents=antecedents_doc2,
        witnesses=witnesses_doc,
        consequents=consequents_doc2,
        antecedent_bias=antencedent_bias,
    ),
    condition(data=observations_doc),
    pyro.plate("sample", 200),
    pyro.poutine.trace() as tr_doc2,
):
    model_doctors()

ac_check(
    trace=tr_doc2,
    mwc=mwc_doc2,
    antecedents=antecedents_doc2,
    witnesses=witnesses_doc,
    consequents=consequents_doc2,
)

The antecedent set is an actual cause.


True

In [26]:
# The third link (Monday's treatment -> Bill being alive) does not hold, so
# transitivity fails!

with (
    MultiWorldCounterfactual() as mwc_doc3,
    ExplainCauses(
        antecedents=antecedents_doc1,
        witnesses=witnesses_doc,
        consequents=consequents_doc2,
        antecedent_bias=antencedent_bias,
    ),
    condition(data=observations_doc),
    pyro.plate("sample", 200),
    pyro.poutine.trace() as tr_doc3,
):
    model_doctors()

ac_check(
    trace=tr_doc3,
    mwc=mwc_doc3,
    antecedents=antecedents_doc1,
    witnesses=witnesses_doc,
    consequents=consequents_doc2,
)

No resulting difference to the consequent in the sample.


### Friendly fire


In [27]:

def model_friendly_fire():
    """Friendly-fire model: two exogenous causes feed two chains that meet at the mistaken call.

    - Battery chain: the PLGR reading feeds the battery death, battery
      change and the post-change PLGR reading.
    - Training chain: training feeds unawareness and the wrong position.

    The chains meet at the mistaken call, which leads to firing, landing and
    the kill. Site order is the same as it has always been.
    """

    def record(name, value):
        # Every endogenous site here is a scalar-event deterministic node.
        return pyro.deterministic(name, value, event_dim=0)

    u_plgr_now = pyro.sample("u_f4_PLGR_now", dist.Bernoulli(0.5))
    u_training = pyro.sample("u_f11_training", dist.Bernoulli(0.5))

    f4_PLGR_now = record("f4_PLGR_now", u_plgr_now)
    f11_training = record("f11_training", u_training)

    # Battery chain.
    f6_PLGR_before = record("f6_PLGR_before", f4_PLGR_now)
    f7_second_calculation = record("f7_second_calculation", f4_PLGR_now)
    f13_battery_died = record(
        "f13_battery_died", f6_PLGR_before.bool() & f7_second_calculation.bool()
    )
    f1_battery_change = record("f1_battery_change", f13_battery_died)
    f12_PLGR_after = record("f12_PLGR_after", f1_battery_change)

    # Training chain.
    f5_unaware = record("f5_unaware", f11_training)
    f14_wrong_position = record("f14_wrong_position", f5_unaware)

    # Both chains must hold for the mistaken call; the rest is a chain.
    f9_mistake_call = record(
        "f9_mistake_call", f12_PLGR_after.bool() & f14_wrong_position.bool()
    )
    f3_fired = record("f3_fired", f9_mistake_call)
    f10_landed = record("f10_landed", f3_fired.bool() & f9_mistake_call.bool())
    f2_killed = record("f2_killed", f10_landed)

    return dict(
        f1_battery_change=f1_battery_change,
        f2_killed=f2_killed,
        f3_fired=f3_fired,
        f4_PLGR_now=f4_PLGR_now,
        f5_unaware=f5_unaware,
        f6_PLGR_before=f6_PLGR_before,
        f7_second_calculation=f7_second_calculation,
        f9_mistake_call=f9_mistake_call,
        f10_landed=f10_landed,
        f11_training=f11_training,
        f12_PLGR_after=f12_PLGR_after,
        f13_battery_died=f13_battery_died,
        f14_wrong_position=f14_wrong_position,
    )

In [28]:
antecedents_fi1 = dict(f6_PLGR_before=0.0, f7_second_calculation=0.0)
consequents_fi = ["f2_killed"]
witnesses_fi = ["f4_PLGR_now", "f5_unaware", "f11_training", "f14_wrong_position"]
# Observed world: both exogenous causes are active.
observations_fi = tensorize_observations(dict(u_f4_PLGR_now=1.0, u_f11_training=1.0))

In [29]:
with (
    MultiWorldCounterfactual() as mwc_fi1,
    ExplainCauses(
        antecedents=antecedents_fi1,
        witnesses=witnesses_fi,
        consequents=consequents_fi,
        antecedent_bias=antencedent_bias,
    ),
    condition(data=observations_fi),
    pyro.plate("sample", 200),
    pyro.poutine.trace() as tr_fi1,
):
    model_friendly_fire()

In [30]:
ac_check(trace=tr_fi1, mwc=mwc_fi1, antecedents=antecedents_fi1, witnesses=witnesses_fi, consequents=consequents_fi)

The antecedent set is not minimal.


False

In [31]:
# Check the smaller antecedent set containing only 'f6_PLGR_before'.
antecedents_fi2 = dict(f6_PLGR_before=0.0)

with (
    MultiWorldCounterfactual() as mwc_fi2,
    ExplainCauses(
        antecedents=antecedents_fi2,
        witnesses=witnesses_fi,
        consequents=consequents_fi,
        antecedent_bias=antencedent_bias,
    ),
    condition(data=observations_fi),
    pyro.plate("sample", 200),
    pyro.poutine.trace() as tr_fi2,
):
    model_friendly_fire()

ac_check(
    trace=tr_fi2,
    mwc=mwc_fi2,
    antecedents=antecedents_fi2,
    witnesses=witnesses_fi,
    consequents=consequents_fi,
)

The antecedent set is an actual cause.


True

### Voting


In [32]:
def voting_model():
    """Six-voter majority model.

    Each ``u_vote{i}`` is an independent ``Bernoulli(0.6)`` vote, copied into a
    deterministic ``vote{i}`` site. ``outcome`` is 1.0 when more than three of
    the six votes are in favour.
    """
    n_voters = 6
    # All exogenous votes are sampled before any deterministic site, in order.
    u_votes = [pyro.sample(f"u_vote{i}", dist.Bernoulli(0.6)) for i in range(n_voters)]
    votes = [
        pyro.deterministic(f"vote{i}", u_vote, event_dim=0)
        for i, u_vote in enumerate(u_votes)
    ]

    # event_dim=0 like every other deterministic site in this notebook.
    # Without it, pyro.deterministic uses event_dim=value.ndim and treats all
    # dimensions (plate and counterfactual worlds included) as event dims.
    outcome = pyro.deterministic("outcome", sum(votes) > 3, event_dim=0).float()
    return {"outcome": outcome}


In [33]:
antecedents_v = dict(vote0=0.0)
outcome_v = ["outcome"]
witnesses_v = ["vote1", "vote2", "vote3", "vote4", "vote5"]
# Observed world: voters 0-3 vote in favour, voters 4-5 against.
observations_v1 = tensorize_observations(
    {f"u_vote{i}": float(i < 4) for i in range(6)}
)

In [34]:
with MultiWorldCounterfactual() as mwc_v1:
    with ExplainCauses(antecedents = antecedents_v, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_v,
                        consequents = outcome_v):
        with condition(data = observations_v1):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_v1:
                    voting_model()

# If you're one of four voters who voted for, you are an actual cause
# of the outcome.
# Pass the same witnesses as the run above (``witnesses_v1`` was never
# defined in this notebook).
ac_check(tr_v1, mwc_v1, antecedents_v, witnesses_v, outcome_v)

The antecedent set is an actual cause.


True

In [None]:
# If you're one of five voters who voted for, you are not an actual cause
# of the outcome: the remaining four votes still carry it.

observations_v2 = tensorize_observations(dict(u_vote0=1., u_vote1=1., u_vote2=1.,
                        u_vote3=1., u_vote4=1., u_vote5=0.))

# Fresh names, so the four-voter results (mwc_v1 / tr_v1) are not overwritten.
with MultiWorldCounterfactual() as mwc_v2:
    with ExplainCauses(antecedents = antecedents_v, antecedent_bias= antencedent_bias,
                        witnesses = witnesses_v,
                        consequents = outcome_v):
        with condition(data = observations_v2):
            with pyro.plate("sample", 200):
                with pyro.poutine.trace() as tr_v2:
                    voting_model()

ac_check(tr_v2, mwc_v2, antecedents_v, witnesses_v, outcome_v)