In [35]:
%reload_ext autoreload
%autoreload 2
%pdb off

import functools

import torch
import pyro
import pyro.distributions as dist

from causal_pyro.indexed.ops import IndexSet, gather, indices_of, scatter
from causal_pyro.interventional.handlers import do
from causal_pyro.counterfactual.handlers import MultiWorldCounterfactual, Preemptions

Automatic pdb calling has been turned OFF


In [36]:
def voting_model():
    """Five-voter majority model.

    Draws five exogenous ``Bernoulli(0.6)`` sites ``u_vote0`` .. ``u_vote4``,
    records each draw as a deterministic site ``vote0`` .. ``vote4``, and
    returns whether the votes sum to at least 3.
    """
    voter_ids = range(5)

    # Register every exogenous site first and only then the deterministic
    # vote sites, so the site order is: all u_vote*, followed by all vote*.
    noise = [pyro.sample(f"u_vote{i}", dist.Bernoulli(0.6)) for i in voter_ids]
    votes = [
        pyro.deterministic(f"vote{i}", u, event_dim=0)
        for i, u in zip(voter_ids, noise)
    ]

    # Left-to-right accumulation: ((((v0 + v1) + v2) + v3) + v4).
    tally = votes[0]
    for vote in votes[1:]:
        tally = tally + vote
    return tally >= 3

In [44]:
def preempt_with_factual(value: torch.Tensor, *, antecedent: str = "", event_dim: int = 0):
    """Replace both slices of ``value`` along ``antecedent`` with its index-0 slice.

    Args:
        value: tensor that is gathered from / scattered into along the
            index variable named by ``antecedent``.
        antecedent: name of the index variable used to build the ``IndexSet`` keys.
        event_dim: forwarded unchanged to both ``gather`` and ``scatter``.

    Returns:
        The result of ``scatter`` where the index-0 slice of ``value`` is
        placed at both index 0 and index 1 of ``antecedent``.
    """
    world0 = IndexSet(**{antecedent: {0}})
    world1 = IndexSet(**{antecedent: {1}})

    # NOTE(review): index 0 is presumably the factual (un-intervened) world,
    # as the function name suggests -- confirm against the causal_pyro docs.
    slice0 = gather(value, world0, event_dim=event_dim)
    return scatter({world0: slice0, world1: slice0}, event_dim=event_dim)

# Action applied to the vote0 site: map a value v to 1 - v.
antecedents = {"vote0": lambda v: 1 - v}

# vote1..vote4 each get their own preemption callable, all keyed on the vote0 index.
preemptions = {
    site: functools.partial(preempt_with_factual, antecedent="vote0")
    for site in ("vote1", "vote2", "vote3", "vote4")
}

# Values for the exogenous u_vote* sites that the model is conditioned on.
observations = {
    "u_vote0": 1.,
    "u_vote1": 0.,
    "u_vote2": 0.,
    "u_vote3": 1.,
    "u_vote4": 1.,
}

# NOTE(review): the two pyro.infer decorators below are disabled (commented out).
# @pyro.infer.infer_discrete(first_available_dim=-7)
# @pyro.infer.config_enumerate
@MultiWorldCounterfactual()
@do(actions=antecedents)
@Preemptions(actions=preemptions)
@pyro.condition(data={site: torch.as_tensor(val) for site, val in observations.items()})
def ac_voting_model():
    """Run ``voting_model`` under the handlers above and compare its result across ``vote0`` indices.

    Gathers the model's return value at ``vote0`` index 1 and index 0, adds a
    ``consequent_differs`` factor that is 0 where the two differ and -1e8
    where they agree, and prints the index sets of the raw and compared values.

    Returns:
        Tuple ``(value at vote0 index 1, value at vote0 index 0, elementwise "differs" mask)``.
    """
    outcome = voting_model()

    world_intervened = IndexSet(vote0={1})
    world_observed = IndexSet(vote0={0})
    outcome_intervened = gather(outcome, world_intervened)
    outcome_observed = gather(outcome, world_observed)
    differs = outcome_intervened != outcome_observed

    # Log-weight 0 where the outcomes differ; a large negative value where they agree.
    log_weight = torch.where(differs, torch.tensor(0.0), torch.tensor(-1e8))
    pyro.factor("consequent_differs", log_weight)

    print(indices_of(outcome), indices_of(differs))
    return outcome_intervened, outcome_observed, differs

# Execute the decorated model once and print its returned tuple:
# (consequent under the vote0 intervention, observed consequent, whether they differ).
print(ac_voting_model())

IndexSet({'vote0': {0, 1}}) IndexSet({})
(tensor([[[[[False]]]]]), tensor([[[[[True]]]]]), tensor([[[[[True]]]]]))
