In [1]:
%env CUDA_VISIBLE_DEVICES=-1

import os

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch

import matplotlib.pyplot as plt

import pandas as pd


from chirho.observational.handlers import condition
from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual
from chirho.explainable.handlers import SearchForExplanation, SearchForNS
                                            
from chirho.indexed.ops import (IndexSet, gather, indices_of) 
from chirho.interventional.handlers import do

env: CUDA_VISIBLE_DEVICES=-1


In [2]:
def ff_disjunctive(p_match_dropped=0.7, p_lightning=0.4):
    """Disjunctive forest-fire model: fire occurs if a match is dropped OR lightning strikes.

    ``match_dropped`` and ``lightning`` are sampled as independent Bernoulli
    variables; ``forest_fire`` is recorded as a deterministic site equal to
    their element-wise maximum, which on 0./1. values is logical OR.

    Args:
        p_match_dropped: probability that a match is dropped (default 0.7).
        p_lightning: probability of a lightning strike (default 0.4).
            The defaults are deliberately unequal.

    Returns:
        dict with keys "match_dropped", "lightning" and "forest_fire" mapping
        to the corresponding sampled / computed tensors.
    """
    match_dropped = pyro.sample("match_dropped", dist.Bernoulli(p_match_dropped))
    lightning = pyro.sample("lightning", dist.Bernoulli(p_lightning))

    # torch.max on {0., 1.} tensors acts as logical OR while keeping float dtype.
    forest_fire = pyro.deterministic(
        "forest_fire", torch.max(match_dropped, lightning), event_dim=0
    )

    return {
        "match_dropped": match_dropped,
        "lightning": lightning,
        "forest_fire": forest_fire,
    }

# Factual observations: a match was dropped, there was no lightning,
# and the forest burned.
observations = {
    "match_dropped": torch.tensor(1.),
    "lightning": torch.tensor(0.),
    "forest_fire": torch.tensor(1.),
}

# Antecedent to probe: the alternative value 0.0 for match_dropped.
antecedents = {"match_dropped": 0.0}

# Witnesses are deliberately left empty for now.
witnesses = {}

# Consequent of interest: forest_fire, under a boolean constraint.
consequents = {"forest_fire": constraints.boolean}


In [3]:
# Seed the RNG so the sampled worlds (and the printed tensor) are
# reproducible under Restart & Run All and when this cell is re-run.
pyro.set_rng_seed(0)

# MultiWorldCounterfactual keeps track of the multiple (factual / counterfactual) worlds.
with MultiWorldCounterfactual() as ffd_mwc:
    with SearchForExplanation(
        antecedents=antecedents,
        witnesses=witnesses,
        consequents=consequents,
        consequent_scale=1e-8,
    ):
        with condition(data=observations):
            # 10 draws, vectorized by pyro.plate.
            with pyro.plate("sample", 10):
                with pyro.poutine.trace() as ffd_tr:
                    ff_disjunctive()

ffd_tr.trace.compute_log_prob()
ffd_nd = ffd_tr.trace.nodes

# Re-enter the counterfactual handler so `gather` can resolve world indices,
# then select index 1 along the match_dropped world dimension
# (the intervened world, as the variable name indicates).
with ffd_mwc:
    original_intervened = gather(
        ffd_nd["__consequent_forest_fire"]["log_prob"],
        IndexSet(match_dropped={1}),
    )
    print(original_intervened)

tensor([[[[[-inf, 0., -inf, 0., 0., 0., 0., 0., 0., -inf]]]]])


In [4]:
# Seed the RNG so this cell's sampled worlds are reproducible on re-run.
pyro.set_rng_seed(0)

# Same query as the SearchForExplanation run, but using the SearchForNS handler.
with MultiWorldCounterfactual() as ffd_ns_mwc:
    with SearchForNS(
        antecedents=antecedents,
        witnesses=witnesses,
        consequents=consequents,
        consequent_scale=1e-8,
    ):
        with condition(data=observations):
            # 10 draws, vectorized by pyro.plate.
            with pyro.plate("sample", 10):
                with pyro.poutine.trace() as ffd_ns_tr:
                    ff_disjunctive()

ffd_ns_tr.trace.compute_log_prob()
ffd_ns_nd = ffd_ns_tr.trace.nodes

# Re-enter the counterfactual handler so `gather` / `indices_of` can resolve
# world indices.
with ffd_ns_mwc:
    new_n_intervened = gather(
        ffd_ns_nd["__consequent_forest_fire"]["log_prob"],
        IndexSet(match_dropped={1}),
    )

    print(new_n_intervened)
    # World indices carried by the forest_fire site's value.
    print(indices_of(ffd_ns_nd["forest_fire"]["value"]))

# Both handlers should yield consequent log-probs of the same shape.
assert new_n_intervened.shape == original_intervened.shape, (
    f"shape mismatch: {new_n_intervened.shape} vs {original_intervened.shape}"
)

tensor([[[[[0., -inf, -inf, 0., -inf, -inf, -inf, 0., -inf, -inf]]]]])
IndexSet({'match_dropped': {0, 1}})
