In [11]:
import functools

import numpy as np

import torch
from typing import Dict, List, Optional,  Union, Callable

import pandas as pd

import pyro
import pyro.distributions as dist

import random


from causal_pyro.indexed.ops import IndexSet, gather, indices_of, scatter
from causal_pyro.interventional.handlers import do
from causal_pyro.counterfactual.handlers import MultiWorldCounterfactual, Preemptions

In [12]:
class HalpernPearlModifiedApproximate:

    def __init__(
        self, 
        model: Callable,
        antecedents: Union[Dict[str, torch.Tensor], List[str]],
        outcome: str,
        witness_candidates: List[str],
        observations: Optional[Dict[str, torch.Tensor]],
        sample_size: int = 100,
        event_dim: int = 0
        ):
        
        self.model = model
        self.antecedents = antecedents
        self.outcome = outcome
        self.witness_candidates = witness_candidates
        self.observations = observations
        self.sample_size = sample_size

        self.antecedents_dict = (
            self.antecedents if isinstance(self.antecedents, dict)
            else self.revert_antecedents(self.antecedents)
        )
    
        self.preemptions = {candidate: functools.partial(self.preempt_with_factual,
                                             antecedents = self.antecedents) for 
                                             candidate in self.witness_candidates}
        

    @staticmethod
    def revert_antecedents(antecedents: List[str]) -> Dict[str, Callable[[torch.Tensor], torch.Tensor]]:
        return {antecedent: (lambda v: 1 - v) for antecedent in antecedents}

    @staticmethod   
    def preempt_with_factual(value: torch.Tensor, *,
                          antecedents: List[str] = None, event_dim: int = 0):
    
        if antecedents is None:
            antecedents = []

        antecedents = [a for a in antecedents if a in indices_of(value, event_dim=event_dim)]

        factual_value = gather(value, IndexSet(**{antecedent: {0} for antecedent in antecedents}),
                                event_dim=event_dim)
            
        return scatter({
            IndexSet(**{antecedent: {0} for antecedent in antecedents}): factual_value,
            IndexSet(**{antecedent: {1} for antecedent in antecedents}): factual_value,
        }, event_dim=event_dim)
        
        
    def __call__(self, *args, **kwargs):
        with pyro.poutine.trace() as trace:
            with MultiWorldCounterfactual():
                with do(actions=self.antecedents_dict):
                    with Preemptions(actions = self.preemptions):
                        with pyro.condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):
                            with pyro.plate("plate", self.sample_size):
                                self.consequent = self.model()[self.outcome]
                                self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.antecedents}))
                                self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.antecedents}))
                                self.consequent_differs = self.intervened_consequent != self.observed_consequent   
                                pyro.factor("consequent_differs", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))
                            
        self.trace = trace.trace

        # slightly hacky solution for odd witness candidate sets
        if  isinstance(self.consequent_differs.squeeze().tolist(), bool):
            self.existential_but_for = self.consequent_differs.squeeze()
        else:
            #if (len(self.consequent_differs.squeeze().tolist() )>1):
            self.existential_but_for = any(self.consequent_differs.squeeze().tolist()                )  

            

        witness_dict = dict()
        if self.witness_candidates:
            witness_keys = ["__split_" + candidate for candidate in self.witness_candidates]
            witness_dict = {key: self.trace.nodes[key]['value']  for key in witness_keys}
            
        witness_dict['observed'] = self.observed_consequent.squeeze()
        witness_dict['intervened'] = self.intervened_consequent.squeeze()
        witness_dict['consequent_differs'] = self.consequent_differs.squeeze()

        # slightly hacky as above
        self.witness_df = pd.DataFrame(witness_dict) if self.witness_candidates else witness_dict

    


In [13]:
def ff_conjunctive():
    u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped",
                                       u_match_dropped, event_dim=0)
    lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
    forest_fire = pyro.deterministic("forest_fire", torch.logical_and(match_dropped, lightning), event_dim=0).float()

    return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}

def ff_disjunctive():
    u_match_dropped = pyro.sample("u_match_dropped", dist.Bernoulli(0.5))
    u_lightning = pyro.sample("u_lightning", dist.Bernoulli(0.5))

    match_dropped = pyro.deterministic("match_dropped",
                                       u_match_dropped, event_dim=0)
    lightning = pyro.deterministic("lightning", u_lightning, event_dim=0)
    forest_fire = pyro.deterministic("forest_fire", torch.logical_or(match_dropped, lightning), event_dim=0).float()

    return {"match_dropped": match_dropped, "lightning": lightning,
            "forest_fire": forest_fire}


In [14]:

def factivity_check(model, antecedents_dict, outcome_dict, observations):
    
    with pyro.condition(data={k: torch.as_tensor(v) for k, v in observations.items()}):
        output = model()
        factivity_tensors = {k: torch.as_tensor(v) for k, v in list(antecedents_dict.items()) + list(outcome_dict.items())}
        return all([factivity_tensors[key] == output[key] for key in factivity_tensors.keys()])



In [15]:
def part_of_minimal_cause(model, antecedents, outcome, nodes, observations, runs_n):

    cache = []
    minimal_antecedents = []

    for step in range(1,runs_n):
        if outcome in nodes:
            nodes.remove(outcome)

        companion_size = random.randint(0,len(nodes))
        companion_candidates = random.sample(nodes, companion_size)

        if set(companion_candidates) in cache:
            continue
        
        cache.append(set(companion_candidates))

        witness_candidates = [node for node in nodes if 
                                node not in antecedents and 
                                node != outcome and 
                                    node not in companion_candidates]
        
        HPM = HalpernPearlModifiedApproximate(
        model = model,
        antecedents = companion_candidates,
        outcome =  outcome,
        witness_candidates = witness_candidates,
        observations = observations,
        sample_size = 1000)
    
        HPM()

        if  not HPM.existential_but_for:
            continue
        
        subset_is_a_minimal_cause = any([s.issubset(set(HPM.antecedents)) for s in minimal_antecedents])
             
        if subset_is_a_minimal_cause:
            continue
        minimal_antecedents.append(set(HPM.antecedents))

        
        for s in minimal_antecedents:
            if set(HPM.antecedents).issubset(s) and s != set(HPM.antecedents):
                minimal_antecedents.remove(s)  


    return {"sufficient_cause": any([set(antecedents).issubset(s) for s in minimal_antecedents]),
            "actual_cause": set(antecedents) in minimal_antecedents,
                "minimal_antecedents" : minimal_antecedents, "cache": cache}
       


In [16]:
part_of_minimal_cause(model = ff_conjunctive, 
                        antecedents = ['lightning'],
                        outcome =  'forest_fire',
                        nodes =  ['match_dropped', 'lightning'],
                        observations = {"u_match_dropped": 1., "u_lightning": 1.},
                        runs_n = 20)

part_of_minimal_cause(model = ff_disjunctive, 
                        antecedents = ['lightning'],
                        outcome =  'forest_fire',
                        nodes =  ['match_dropped', 'lightning'],
                        observations = {"u_match_dropped": 1., "u_lightning": 1.},
                        runs_n = 20)

{'sufficient_cause': True,
 'actual_cause': False,
 'minimal_antecedents': [{'lightning', 'match_dropped'}],
 'cache': [set(),
  {'lightning', 'match_dropped'},
  {'match_dropped'},
  {'lightning'}]}

In [17]:
def overlap_with_cause(model, antecedents, outcome, nodes, observations, runs_n = 20):
    
    minimal_ante = part_of_minimal_cause(model, antecedents, outcome, nodes, observations, runs_n)['minimal_antecedents']
    antecedents_set = set(antecedents)
    
    overlaps = [antecedents_set.intersection(s) for s in minimal_ante if antecedents_set.intersection(s)]
    overlap = any(overlaps)
    return {"overlap": overlap, "overlaps": overlaps    }
    

overlap_with_cause(model = ff_disjunctive, 
                        antecedents = ['lightning', 'blah'],
                        outcome =  'forest_fire',
                        nodes =  ['match_dropped', 'lightning'],
                        observations = {"u_match_dropped": 1., "u_lightning": 1.},
                        runs_n = 20)

{'overlap': True, 'overlaps': [{'lightning'}]}

In [18]:
def ensurer(model, exogenous_variables, antecedents_dict, outcome_dict, runs_n):

    settings_cache = []
    intervened_consequent = []

    outcome = list(outcome_dict.keys())[0]
    antecedents = [key for key in antecedents_dict.keys()]

    for step in range(1,runs_n):
        
        random_setting = [random.choice([0., 1.]) for _ in range(len(exogenous_variables))]
        if random_setting in settings_cache:
            continue
        
        settings_cache.append(random_setting)

        observations = {var: val for var, val in zip(exogenous_variables, random_setting)}

        with pyro.condition(data={k: torch.as_tensor(v) for k, v in observations.items()}):
                with MultiWorldCounterfactual():
                    with do(actions=antecedents_dict):
                        intervened_consequent.append(
                             gather(model()[outcome], 
                                    IndexSet(**{ant: {1} for ant in antecedents})).squeeze().item())
                        
    return {"ensurer": all(intervened_consequent),
            "settings_cache": settings_cache, 
            "intervened_consequent": intervened_consequent}

    print(settings_cache)
    print(observations)     
    print(intervened_consequent)   
    print(all(intervened_consequent))




ensurer(model = ff_conjunctive,
        exogenous_variables = ["u_match_dropped", "u_lightning"],
        antecedents_dict = {"match_dropped": 1., "lightning": 1.},
        outcome_dict = {"forest_fire": 1.},
                        runs_n = 10)

{'ensurer': True,
 'settings_cache': [[1.0, 1.0], [0.0, 0.0], [0.0, 1.0], [1.0, 0.0]],
 'intervened_consequent': [1.0, 1.0, 1.0, 1.0]}

In [58]:
def sufficient_cause(model, exogenous_variables, antecedents_dict, outcome_dict, nodes, observations, runs_n):

    factivity = factivity_check(model = model,
                antecedents_dict = antecedents_dict,
                outcome_dict = outcome_dict, 
                observations = observations)
    
    if not factivity:
        return {"sufficient_cause": False, "reason": {"factivity": False}}   

    
    ensure = ensurer(model = model,
        exogenous_variables = exogenous_variables,
        antecedents_dict = antecedents_dict,
        outcome_dict = outcome_dict,
        runs_n = runs_n)['ensurer']
    
    if not ensure:
        return {"sufficient_cause": False, "reason": {"ensure": False}}
    
    overlap = overlap_with_cause(model = model,
                        antecedents = [key for key in antecedents_dict.keys()],
                        outcome =  list(outcome_dict.keys())[0],
                        nodes =  nodes,
                        observations = observations,
                        runs_n= runs_n)
    
    if not overlap['overlap']:
        return {"sufficient_cause": False, "reason": {"overlap": False}}


    # minimality check starts here
    antecedents = [key for key in antecedents_dict.keys()]
    subsets = [[]]
    for node in antecedents:
        subsets.extend([subset + [node] for subset in subsets])
    subsets.pop()

    
    for subset in subsets:
    
        subset_ensure = ensurer(model = model,
            exogenous_variables = exogenous_variables,
            antecedents_dict = {key: antecedents_dict[key] for key in subset},
            outcome_dict = outcome_dict,
            runs_n = runs_n)['ensurer']    
    
        if not subset_ensure:
            continue

        subset_overlap = overlap_with_cause(model = model,
                    antecedents = subset,
                    outcome =  list(outcome_dict.keys())[0],
                    nodes =  nodes,
                    observations = observations,
                    runs_n= runs_n)['overlap']

        if subset_ensure and subset_overlap:
            
            return {"sufficient_cause": False, "reason": {"minimality": False, "subset": subset}}
      # minimality check ends here 

                
        
sufficient_cause(model = ff_conjunctive,
                 exogenous_variables= ["u_match_dropped", "u_lightning"],
                    antecedents_dict = {"match_dropped": 1., "lightning": 1.},
                    outcome_dict = {"forest_fire": 1.},
                    nodes = ["match_dropped", "lightning"],
                    observations= {"u_match_dropped": 1., "u_lightning": 1.},
                    runs_n = 20)


{'sufficient_cause': False,
 'reason': {'minimality': False, 'subset': ['match_dropped']}}

In [313]:

settings_cache = []
HPMs = []
minimal_antecedents_cache = []

model = ff_disjunctive
exogenous_variables = ["u_match_dropped", "u_lightning"]
antecedents_dict = {"match_dropped": 1., "lightning": 1.}
outcome_dict = {"forest_fire": 1.}
nodes = ["match_dropped", "lightning"]
witness_candidates = []



antecedents = [key for key in antecedents_dict.keys()]
outcome = list(outcome_dict.keys())[0]
if outcome in nodes:
    nodes.remove(outcome)


pyro.set_rng_seed(0)
random_setting = [random.choice([0., 1.]) for _ in range(len(exogenous_variables))]

#if random_setting in settings_cache:
#    continue

settings_cache.append(random_setting)

observations = {var: val for var, val in zip(exogenous_variables, random_setting)}

print(observations)

factivity = factivity_check(model = model,
                antecedents_dict = antecedents_dict,
                outcome_dict = outcome_dict, 
                observations = observations)


print(factivity)
#if not factivity:
#    continue

part_of_minimal = part_of_minimal_cause(model = model,
                        antecedents = antecedents,
                        outcome =  outcome,
                        nodes = nodes,
                        observations = observations,
                        runs_n = 20)['sufficient_cause']




{'u_match_dropped': 1.0, 'u_lightning': 1.0}
True


In [None]:


   


print(outcome)


if random_setting not in settings:

   

    HPM = HalpernPearlModifiedApproximate(
        model=model,
        antecedents=antecedents,
        outcome=outcome,
        witness_candidates=witness_candidates,
        observations=observations,
        sample_size=1000
    )
    HPM()
    HPMs.append(HPM)


    if  HPM.existential_but_for:

        subset_in_cache = any([s.issubset(set(HPM.antecedents)) for s in self.minimal_antecedents_cache])
                if not subset_in_cache:
                    self.minimal_antecedents_cache.append(set(HPM.antecedents))

                    if self.antecedent in HPM.antecedents:
                        self.antecedent_sizes.append(len(HPM.antecedents))
                        self.existential_but_fors.append(HPM.existential_but_for)
                        self.minimal_witness_sizes.append(HPM.minimal_witness_size)
                        self.responsibilities.append(HPM.responsibility_internal)






In [None]:


#factivity check to avoid needles computation




In [None]:


[attr for attr in dir(step_HPM) if not attr.startswith('__')]

    settings.append(random_setting)
    HPMs.append(HPM(model, antecedents, outcome, observations, witness_candidates))
