In [1]:
import pyro
from typing import Mapping, Union, TypeVar, List, Optional
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
from chirho.interventional.ops import Intervention

T = TypeVar("T")

In [4]:
class InferSupports(pyro.poutine.messenger.Messenger):
    supports: Mapping[str, pyro.distributions.constraints.Constraint]

    def __init__(
        self,
        antecedents: Optional[
            Union[
                List[str],
                Mapping[str, Intervention[T]],
                Mapping[str, constraints.Constraint],
            ]
        ] = None,
        witnesses: Optional[
            Union[
                List[str],
                Mapping[str, Intervention[T]],
                Mapping[str, constraints.Constraint],
            ]
        ] = None,
        consequents: Optional[
            Union[
                List[str],
                Mapping[str, Intervention[T]],
                Mapping[str, constraints.Constraint],
            ]
        ] = None,
    ):
        super(InferSupports, self).__init__()

        for group in ["antecedents", "witnesses", "consequents"]:
            setattr(self, group, self._extract_keys(locals()[group]))

        self.supports = {}

    def _extract_keys(
        self,
        data: Optional[
            Union[List[str], Mapping[str, Union[T, constraints.Constraint]]]
        ],
    ) -> Optional[List[str]]:
        if data is None:
            return None
        elif isinstance(data, dict):
            return list(data.keys())
        else:
            return data

    def _pyro_post_sample(self, msg: dict) -> None:
        if not pyro.poutine.util.site_is_subsample(msg):
            self.supports[msg["name"]] = msg["fn"].support

    def __exit__(self, exc_type, exc_value, traceback):
        for group in ["antecedents", "witnesses", "consequents"]:
            keys = getattr(self, group)
            setattr(self, group, {})
            if keys:
                if not all(key in self.supports for key in keys):
                    raise ValueError(
                        f"Invalid keys in {group}. Ensure that all keys exist in self.supports."
                    )

                setattr(self, group, {key: self.supports[key] for key in keys})

        return super(InferSupports, self).__exit__(exc_type, exc_value, traceback)

In [5]:
def mixed_supports_model():
    uniform_var = pyro.sample("uniform_var", dist.Uniform(1, 10))
    normal_var = pyro.sample("normal_var", dist.Normal(3, 15))
    bernoulli_var = pyro.sample("bernoulli_var", dist.Bernoulli(0.5))
    positive_var = pyro.sample("positive_var", dist.LogNormal(0, 1))


with InferSupports() as s1:
    mixed_supports_model()

print(s1.supports)


with InferSupports(antecedents=["uniform_var"]) as s2:
    mixed_supports_model()

print("antecedents", s2.antecedents)


with InferSupports(antecedents={}, witnesses=["normal_var"]) as s3:
    mixed_supports_model()

print(s3.antecedents, s3.witnesses, s3.consequents)

TypeError: 'NoneType' object is not iterable

In [8]:
import pytest
import pyro
import pyro.distributions as dist
from itertools import chain, combinations


def mixed_supports_model():
    def model():
        uniform_var = pyro.sample("uniform_var", dist.Uniform(1, 10))
        normal_var = pyro.sample("normal_var", dist.Normal(3, 15))
        bernoulli_var = pyro.sample("bernoulli_var", dist.Bernoulli(0.5))
        positive_var = pyro.sample("positive_var", dist.LogNormal(0, 1))

    return model


options = [
    None,
    [],
    ["uniform_var"],
    ["uniform_var", "normal_var"],
    ["uniform_var", "normal_var", "bernoulli_var"],
    ["uniform_var", "normal_var", "bernoulli_var", "positive_var"],
    {},
    {"positive_var": 5.0},
    {"uniform_var": 5.0, "bernoulli_var": 5.0},
    {
        "uniform_var": constraints.interval(1, 10),
        "bernoulli_var": constraints.interval(0, 1),
    },  # misspecified on purpose, should make no damage
]


with InferSupports() as s1:
    mixed_supports_model()

print(s1.supports)

antecedents = options[1]
witnesses = options[1]
consequents = options[2]

with InferSupports(
    antecedents=antecedents, witnesses=witnesses, consequents=consequents
) as s:
    mixed_supports_model()


print(s.supports)

{}
['uniform_var']
{}


ValueError: Invalid keys in consequents. Ensure that all keys exist in self.supports.

In [None]:
def test_infer_supports(antecedents, witnesses, consequents, mixed_supports_model):
    with InferSupports(
        antecedents=antecedents, witnesses=witnesses, consequents=consequents
    ) as s:
        mixed_supports_model()

    # if antecedents is not None:
    #     assert s.antecedents.keys() ==  antecedents
    #     for key in antecedents:
    #         assert isinstance(s.supports[key], constraints.Constraint)
    #         assert s.supports[key] == antecedents[key]
    # if witnesses is not None:
    #     assert s.witnesses.keys() == witnesses
    #     for key in witnesses:
    #         assert isinstance(s.supports[key], constraints.Constraint)
    #         assert s.supports[key] == witnesses[key]
    # if consequents is not None:
    #     assert s.consequents.keys() == consequents
    #     for key in consequents:
    #         assert isinstance(s.supports[key], constraints.Constraint)
    #         assert s.supports[key] == consequents[key]

    # assert s.supports.keys() == {"uniform_var", "normal_var", "bernoulli_var", "positive_var"}
    # for key in s.supports.keys():
    #     assert isinstance(s.supports[key], constraints.Constraint)
    #     if key == "uniform_var":
    #         assert s.supports[key] == constraints.interval(1, 10)
    #     elif key == "normal_var":
    #         assert s.supports[key] == constraints.real
    #     elif key == "bernoulli_var":
    #         assert s.supports[key] == constraints.boolean
    #     elif key == "positive_var":
    #         assert s.supports[key] == constraints.positive


test_infer_supports(antecedents, witnesses, consequents, mixed_supports_model)