In [46]:
%reload_ext autoreload
%autoreload 2
%pdb off

import functools

import torch
from typing import Dict, List, Optional, Tuple, Union, TypeVar, Callable

import pandas as pd

import pyro
import pyro.distributions as dist

from causal_pyro.indexed.ops import IndexSet, gather, indices_of, scatter
from causal_pyro.interventional.handlers import do
from causal_pyro.counterfactual.handlers import MultiWorldCounterfactual, Preemptions

Automatic pdb calling has been turned OFF


In [59]:
class HalpernPearlModifiedApproximate:

    def __init__(self, model: Callable,
                 antecedents: Union[Dict[str, torch.Tensor], List[str]],
                 witness_candidates: List[str],
                 observations: Dict[str, torch.Tensor],
                 sample_size: int = 100,
                 intervention, counterfactual, preemptions,
                 event_dim: int = 0):
        
        self.model = model
        self.antecedents = antecedents
        self.witness_candidates = witness_candidates
        self.observations = observations
        self.sample_size = sample_size

        self.antecedents_dict = (
            self.antecedents if isinstance(self.antecedents, dict)
            else self.revert_antecedents(self.antecedents)
        )
    
        self.preemptions = {candidate: functools.partial(self.preempt_with_factual,
                                             antecedents = self.antecedents) for 
                                             candidate in self.witness_candidates}
        

    @staticmethod
    def revert_antecedents(antecedents: List[str]) -> Dict[str, Callable[[torch.Tensor], torch.Tensor]]:
        return {antecedent: lambda v: 1 - v for antecedent in antecedents}

    @staticmethod   
    def preempt_with_factual(value: torch.Tensor, *,
                          antecedents: List[str] = None, event_dim: int = 0):
    
        if antecedents is None:
            antecedents = []

        antecedents = [a for a in antecedents if a in indices_of(value, event_dim=event_dim)]

        factual_value = gather(value, IndexSet(**{antecedent: {0} for antecedent in antecedents}),
                                event_dim=event_dim)
            
        return scatter({
            IndexSet(**{antecedent: {0} for antecedent in antecedents}): factual_value,
            IndexSet(**{antecedent: {1} for antecedent in antecedents}): factual_value,
        }, event_dim=event_dim)
        
        
    def __call__(self, *args, **kwargs):
        
        

    


In [2]:
def voting_model():
    u_vote0 = pyro.sample("u_vote0", dist.Bernoulli(0.6))
    u_vote1 = pyro.sample("u_vote1", dist.Bernoulli(0.6))
    u_vote2 = pyro.sample("u_vote2", dist.Bernoulli(0.6))
    u_vote3 = pyro.sample("u_vote3", dist.Bernoulli(0.6))
    u_vote4 = pyro.sample("u_vote4", dist.Bernoulli(0.6))

    vote0 = pyro.deterministic("vote0", u_vote0, event_dim=0)
    vote1 = pyro.deterministic("vote1", u_vote1, event_dim=0)
    vote2 = pyro.deterministic("vote2", u_vote2, event_dim=0)
    vote3 = pyro.deterministic("vote3", u_vote3, event_dim=0)
    vote4 = pyro.deterministic("vote4", u_vote4, event_dim=0)
    return {"outcome": vote0 + vote1 + vote2 + vote3 + vote4 >= 3}

# dict here seems handy 

In [3]:
model = voting_model

In [25]:

def revert_antecedents(antecedents: List[str]) -> Dict[str, Callable[[torch.Tensor], torch.Tensor]]:
        return {antecedent: lambda v: 1 - v for antecedent in antecedents}


antecedents = ["vote0", "vote3"]

antecedents_reverted = revert_antecedents(antecedents)


In [26]:

def preempt_with_factual(value: torch.Tensor, *,
                          antecedents: List[str] = None, event_dim: int = 0):
    
    if antecedents is None:
        antecedents = []

    antecedents = [a for a in antecedents if a in indices_of(value, event_dim=event_dim)]

    factual_value = gather(value, IndexSet(**{antecedent: {0} for antecedent in antecedents}),
                            event_dim=event_dim)
        
    return scatter({
        IndexSet(**{antecedent: {0} for antecedent in antecedents}): factual_value,
        IndexSet(**{antecedent: {1} for antecedent in antecedents}): factual_value,
    }, event_dim=event_dim)


In [27]:
witness_candidates = [f"vote{i}" for i in [1,2,4]]

preemptions = {candidate: functools.partial(preempt_with_factual,
                                             antecedents = antecedents) for 
                                             candidate in witness_candidates}


{'vote1': functools.partial(<function preempt_with_factual at 0x7f73640152d0>, antecedents=['vote0', 'vote3']), 'vote2': functools.partial(<function preempt_with_factual at 0x7f73640152d0>, antecedents=['vote0', 'vote3']), 'vote4': functools.partial(<function preempt_with_factual at 0x7f73640152d0>, antecedents=['vote0', 'vote3'])}


In [49]:

observations = dict(u_vote0=1., u_vote1=0., u_vote2=0., u_vote3=1., u_vote4=1.)

sample_size = 10

In [50]:
with pyro.poutine.trace() as tr:
  with MultiWorldCounterfactual():
      with do(actions=antecedents_reverted):
        with Preemptions(actions = preemptions):
          with pyro.condition(data={k: torch.as_tensor(v) for k, v in observations.items()}):
              with pyro.plate("plate", sample_size):
                consequent = model()['outcome']
                intervened_consequent = gather(consequent, IndexSet(**{ant: {1} for ant in antecedents}))
                observed_consequent = gather(consequent, IndexSet(**{ant: {0} for ant in antecedents}))
                consequent_differs = intervened_consequent != observed_consequent   
                pyro.factor("consequent_differs", torch.where(consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))

                print(
                      {"intervened_consequent": intervened_consequent.squeeze().tolist(),
                      "observed_consequent": observed_consequent.squeeze().tolist(),
                      "consequent_differs": consequent_differs.squeeze().tolist()}
                )


{'intervened_consequent': [False, False, False, False, False, False, False, False, False, False], 'observed_consequent': [True, True, True, True, True, True, True, True, True, True], 'consequent_differs': [True, True, True, True, True, True, True, True, True, True]}


In [51]:
witness_keys = ["__split_" + candidate for candidate in witness_candidates]

witness_dict = {key: tr.trace.nodes[key]['value']  for key in witness_keys}

witness_dict['consequent_differs'] = consequent_differs.squeeze()

witness_df = pd.DataFrame(witness_dict)

print(witness_df)



   __split_vote1  __split_vote2  __split_vote4  consequent_differs
0              0              0              0                True
1              1              0              1                True
2              0              1              0                True
3              0              0              1                True
4              1              0              0                True
5              1              0              1                True
6              1              1              1                True
7              0              0              0                True
8              0              0              1                True
9              0              1              0                True


In [118]:
def actual_causality(model, 
                    antecedents: List[str],
                    witness_candidates: List[str], 
                    outcome: str, 
                    observations: Dict[str, torch.Tensor],
                    *, event_dim: int = 0):

SyntaxError: incomplete input (1201503155.py, line 6)

In [140]:


# @pyro.infer.infer_discrete(first_available_dim=-7)
# @pyro.infer.config_enumerate
@MultiWorldCounterfactual()
@do(actions=antecedents2)
@Preemptions(actions=preemptions2)
@pyro.condition(data={k: torch.as_tensor(v) for k, v in observations.items()})
def ac_voting_model():
    consequent = voting_model()['outcome']
    intervened_consequent = gather(consequent, IndexSet(vote0={1}))
    observed_consequent = gather(consequent, IndexSet(vote0={0}))
    consequent_differs = intervened_consequent != observed_consequent
    pyro.factor("consequent_differs", torch.where(consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))
    print(indices_of(consequent), indices_of(consequent_differs))
    return intervened_consequent, observed_consequent, consequent_differs

print(ac_voting_model())

TypeError: list indices must be integers or slices, not str