In [58]:
%reload_ext autoreload
%autoreload 2
%pdb off

import functools

import torch
from typing import Dict, List, Optional, Tuple, Union, TypeVar, Callable

import pandas as pd

import pyro
import pyro.distributions as dist

from causal_pyro.indexed.ops import IndexSet, gather, indices_of, scatter
from causal_pyro.interventional.handlers import do
from causal_pyro.counterfactual.handlers import MultiWorldCounterfactual, Preemptions

Automatic pdb calling has been turned OFF


In [44]:
class HalpernPearlModifiedApproximate:
    """Sampling-based check of whether intervening on antecedents changes an outcome.

    Calling an instance runs ``model`` inside a ``MultiWorldCounterfactual``
    context in which

    * every antecedent is intervened on (``do``),
    * every witness candidate is wrapped in a ``Preemptions`` action that
      overwrites its value in the intervened world with its factual value,
    * the supplied observations are conditioned on, and
    * ``sample_size`` independent samples are drawn in a plate.

    Results are stored on the instance by ``__call__``:

    * ``consequent`` -- the model's outcome value across worlds,
    * ``intervened_consequent`` / ``observed_consequent`` -- the outcome in the
      intervened (index 1) and factual (index 0) world,
    * ``consequent_differs`` -- elementwise ``intervened != observed``,
    * ``mean_consequent_difference`` -- fraction of samples where they differ,
    * ``trace`` -- the recorded Pyro trace,
    * ``witness_df`` -- one column per ``__split_<candidate>`` trace site plus
      a ``consequent_differs`` column.

    Args:
        model: Callable returning a mapping from which the outcome is read
            using ``outcome`` as the key.
        antecedents: Either a dict ``{site_name: intervention}`` passed to
            ``do`` as-is, or a list of site names; for a list, each site is
            intervened on with ``lambda v: 1 - v``.
        outcome: Key of the outcome in the mapping returned by ``model``.
        witness_candidates: Names of sites eligible for preemption.
        observations: Values to condition on, keyed by site name; ``None`` is
            treated as "no observations".
        sample_size: Size of the sampling plate.
        event_dim: Forwarded to ``preempt_with_factual`` for the witness
            candidates. NOTE(review): assumed to be the event dimension of the
            witness sites -- confirm against the models used.
    """

    def __init__(
        self,
        model: Callable,
        antecedents: Union[Dict[str, torch.Tensor], List[str]],
        outcome: str,
        witness_candidates: List[str],
        observations: Optional[Dict[str, torch.Tensor]],
        sample_size: int = 100,
        event_dim: int = 0,
    ):
        self.model = model
        self.antecedents = antecedents
        self.outcome = outcome
        self.witness_candidates = witness_candidates
        self.observations = observations
        self.sample_size = sample_size
        self.event_dim = event_dim

        # Normalise antecedents to the {name: intervention} form that `do` takes.
        self.antecedents_dict = (
            self.antecedents
            if isinstance(self.antecedents, dict)
            else self.revert_antecedents(self.antecedents)
        )

        # One preemption action per witness candidate; each only needs the
        # antecedent names (to locate the factual / intervened worlds).
        self.preemptions = {
            candidate: functools.partial(
                self.preempt_with_factual,
                antecedents=list(self.antecedents_dict),
                event_dim=self.event_dim,
            )
            for candidate in self.witness_candidates
        }

    @staticmethod
    def revert_antecedents(antecedents: List[str]) -> Dict[str, Callable[[torch.Tensor], torch.Tensor]]:
        """Map each antecedent name to the intervention ``v -> 1 - v``."""
        return {antecedent: lambda v: 1 - v for antecedent in antecedents}

    @staticmethod
    def preempt_with_factual(
        value: torch.Tensor,
        *,
        antecedents: Optional[List[str]] = None,
        event_dim: int = 0,
    ):
        """Return ``value`` with its factual entry copied into both worlds.

        Only antecedents that actually index ``value`` (per ``indices_of``) are
        considered. The entry at index 0 of those antecedents is gathered and
        scattered back to both index 0 and index 1.

        Args:
            value: Value of the site being preempted.
            antecedents: Antecedent names; ``None`` means no antecedents.
            event_dim: Passed through to ``indices_of``, ``gather`` and ``scatter``.
        """
        if antecedents is None:
            antecedents = []

        present = indices_of(value, event_dim=event_dim)
        antecedents = [a for a in antecedents if a in present]

        factual_world = IndexSet(**{antecedent: {0} for antecedent in antecedents})
        intervened_world = IndexSet(**{antecedent: {1} for antecedent in antecedents})

        factual_value = gather(value, factual_world, event_dim=event_dim)

        return scatter(
            {
                factual_world: factual_value,
                intervened_world: factual_value,
            },
            event_dim=event_dim,
        )

    def __call__(self, *args, **kwargs):
        """Run the sampler and populate the result attributes (see class docstring).

        ``*args`` and ``**kwargs`` are forwarded to ``model``. Returns ``None``.
        """
        # `observations` is Optional: treat None as an empty conditioning set.
        observations = self.observations or {}
        conditioning_data = {k: torch.as_tensor(v) for k, v in observations.items()}

        # Use the instance's antecedent names, never a notebook-level global.
        antecedent_names = list(self.antecedents_dict)

        with pyro.poutine.trace() as trace:
            with MultiWorldCounterfactual():
                with do(actions=self.antecedents_dict):
                    with Preemptions(actions=self.preemptions):
                        with pyro.condition(data=conditioning_data):
                            with pyro.plate("plate", self.sample_size):
                                self.consequent = self.model(*args, **kwargs)[self.outcome]
                                self.intervened_consequent = gather(
                                    self.consequent,
                                    IndexSet(**{ant: {1} for ant in antecedent_names}),
                                )
                                self.observed_consequent = gather(
                                    self.consequent,
                                    IndexSet(**{ant: {0} for ant in antecedent_names}),
                                )
                                self.consequent_differs = (
                                    self.intervened_consequent != self.observed_consequent
                                )
                                # Log-factor of 0 where the outcome changed and
                                # -1e8 where it did not.
                                pyro.factor(
                                    "consequent_differs",
                                    torch.where(
                                        self.consequent_differs,
                                        torch.tensor(0.0),
                                        torch.tensor(-1e8),
                                    ),
                                )

        self.trace = trace.trace
        self.mean_consequent_difference = self.consequent_differs.squeeze().float().mean().item()

        # Collect the `__split_<candidate>` sites recorded in the trace
        # (presumably the per-sample preemption indicators -- confirm against
        # the Preemptions handler).
        witness_keys = ["__split_" + candidate for candidate in self.witness_candidates]
        witness_dict = {key: self.trace.nodes[key]["value"] for key in witness_keys}
        witness_dict["consequent_differs"] = self.consequent_differs.squeeze()

        self.witness_df = pd.DataFrame(witness_dict)

    


In [7]:
def voting_model():
    """Five-voter majority model.

    Each voter ``i`` has a noise site ``u_vote{i}`` drawn from ``Bernoulli(0.6)``
    and a deterministic site ``vote{i}`` equal to that draw. Returns a dict whose
    ``"outcome"`` entry is true where at least three of the five votes are 1.
    """
    voter_ids = range(5)

    # Register every noise site before any vote site, matching the site order
    # u_vote0..u_vote4, vote0..vote4.
    noise_draws = [
        pyro.sample(f"u_vote{i}", dist.Bernoulli(0.6)) for i in voter_ids
    ]
    ballots = [
        pyro.deterministic(f"vote{i}", draw, event_dim=0)
        for i, draw in zip(voter_ids, noise_draws)
    ]

    # Left-to-right accumulation: vote0 + vote1 + vote2 + vote3 + vote4.
    tally = ballots[0]
    for ballot in ballots[1:]:
        tally = tally + ballot

    return {"outcome": tally >= 3}

# Returning a dict seems handy here: the outcome can be looked up by key downstream.

In [76]:
# Query setup: vote0 is the antecedent, the other four votes are witness
# candidates, and the observed noise values are 1, 0, 0, 1, 1.
votingHPM = HalpernPearlModifiedApproximate(
    model=voting_model,
    antecedents=["vote0"],
    outcome="outcome",
    witness_candidates=[f"vote{i}" for i in range(1, 5)],
    observations={
        "u_vote0": 1.0,
        "u_vote1": 0.0,
        "u_vote2": 0.0,
        "u_vote3": 1.0,
        "u_vote4": 1.0,
    },
    sample_size=1000,
)

In [77]:
# Run the sampler; this sets the result attributes read below.
votingHPM()

# Fraction of samples in which the intervened and factual outcomes differ.
print(votingHPM.mean_consequent_difference)

# Per-sample `__split_*` values alongside whether the outcome differed.
votingHPM.witness_df

1.0


Unnamed: 0,__split_vote1,__split_vote2,__split_vote3,__split_vote4,consequent_differs
0,0,0,1,1,True
1,1,0,0,0,True
2,0,1,1,0,True
3,1,0,0,0,True
4,0,1,0,1,True
...,...,...,...,...,...
995,0,1,0,1,True
996,1,1,1,1,True
997,1,1,1,1,True
998,1,1,0,1,True
