Effect handling seems like a principled way to compose transformations of functions: each handler wraps a function and intercepts the effectful operations it performs.

In [1]:
import minipyro

In [2]:
fn = lambda x: x**2

msgr = minipyro.Messenger(fn)

In [3]:
msgr(3)

Entering Messenger <minipyro.Messenger object at 0x7fc3b82fc3a0>
Exiting Messenger <minipyro.Messenger object at 0x7fc3b82fc3a0>


9
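The enter/exit lines above suggest that a `Messenger` is both a context manager and a callable wrapper around `fn`. A minimal sketch of that pattern follows, assuming nothing about the actual `minipyro` source: the class name `LoggingMessenger` and its `events` list are hypothetical, and events are recorded rather than printed so that their order can be inspected.

```python
class LoggingMessenger:
    """Hypothetical stand-in for a Messenger that logs when it is entered and exited."""

    def __init__(self, fn=None):
        self.fn = fn
        self.events = []  # recorded instead of printed, so the order can be checked

    def __enter__(self):
        self.events.append("enter")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.events.append("exit")
        return False  # do not swallow exceptions raised inside the block

    def __call__(self, *args, **kwargs):
        # Calling the wrapper runs the wrapped function inside the handler's context.
        with self:
            return self.fn(*args, **kwargs)


msgr = LoggingMessenger(lambda x: x**2)
result = msgr(3)
print(result, msgr.events)
```

The wrapped function's return value passes through unchanged; the handler only brackets the call.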

In [4]:
tr = minipyro.trace(fn)

In [5]:
with minipyro.trace() as test_trace:
    print(test_trace.values())

Entering Messenger <minipyro.trace object at 0x7fc3a902cbe0>
odict_values([])
Exiting Messenger <minipyro.trace object at 0x7fc3a902cbe0>


In [13]:
fn = lambda x: x**2
loc = minipyro.sample("testname", fn, x=3)

In [20]:
import pyro
d = pyro.distributions.Normal(0.0, 1.0)

# With no handlers on the stack, this simply draws a sample from d
loc = minipyro.sample("loc", d)

In [22]:
def model(data=None):
    d = pyro.distributions.Normal(0.0, 1.0)
    
    # With no handlers on the stack, this simply draws a sample.
    # Otherwise, every handler on the stack is applied to the sample site.
    loc = minipyro.sample("loc", d, obs=data)

In [25]:
with minipyro.trace() as capt:
    model()
    print(capt)

Entering Messenger <minipyro.trace object at 0x7fc319cc4ee0>
OrderedDict([('loc', {'type': 'sample', 'name': 'loc', 'fn': Normal(loc: 0.0, scale: 1.0), 'args': (), 'kwargs': {}, 'value': tensor(0.4431)})])
Exiting Messenger <minipyro.trace object at 0x7fc319cc4ee0>


\begin{align}
    \text{weight} & \sim \mathcal{N}\left(\text{guess}, 1.0\right) \\
    \text{measurement}\; |\; \text{weight}, \text{guess} & \sim \mathcal{N}\left(\text{weight}, 0.75\right) \\
\end{align}

In [2]:
import pyro
import pyro.distributions as dist

def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))    

This defines a joint probability distribution over `weight` and `measurement` (where `guess` is a hyperparameter):

\begin{align*}
    w, m \; | \; g & \sim p\left(w, m \, \middle| \, g\right) \\
    & = p\left(m\,\middle|\, w, g\right)p\left(w\,\middle|\, g\right) \\
    & = p\left(m\,\middle|\, w\right)p\left(w\,\middle|\, g\right) \\
    & = \mathcal{N}\left(m; w, 0.75\right)\mathcal{N}\left(w; g, 1.0\right)
\end{align*}

where we assume $m$ is conditionally independent of $g$ given $w$.
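
Taking logs of this factorization gives the log joint as a sum of one term per sample site. The expansion below assumes the second argument of each normal is a standard deviation, which is how `dist.Normal(loc, scale)` is parameterized in the code above:

\begin{align*}
    \log p\left(w, m \,\middle|\, g\right) & = \log \mathcal{N}\left(m; w, 0.75\right) + \log \mathcal{N}\left(w; g, 1.0\right) \\
    & = -\log\left(2\pi\right) - \log 0.75 - \frac{\left(m - w\right)^2}{2 \cdot 0.75^2} - \frac{\left(w - g\right)^2}{2}
\end{align*}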

Given the model as a Python function (`scale()` above), we don't have direct access to the distributions used inside it, so it is hard to compute quantities such as the joint (or log-joint) probability directly.

`poutine` is Pyro's library of effect handlers: composable building blocks that wrap a model and record or modify the behavior of its `pyro.sample` statements.

In [33]:
import pyro.poutine as poutine

def make_log_joint(model):
    def _log_joint(cond_data, *args, **kwargs):
        conditioned_model = poutine.condition(model, data=cond_data)
        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
        return trace.log_prob_sum()
    return _log_joint


In [34]:
scale_log_joint = make_log_joint(scale)

In [42]:
import torch

print(scale_log_joint({"measurement": torch.tensor(9.5), "weight": torch.tensor(8.23)}, 8.5))

tensor(-3.0203)
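
This value can be checked by hand without Pyro, since the log joint is the sum of the two normal log densities in the model. The sketch below uses only the standard library; `normal_log_prob` is a hypothetical helper, and it treats `1.0` and `0.75` as standard deviations, matching the `scale` argument of `dist.Normal`.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log density of a normal with mean `loc` and standard deviation `scale`."""
    z = (x - loc) / scale
    return -0.5 * math.log(2.0 * math.pi) - math.log(scale) - 0.5 * z * z


guess, weight, measurement = 8.5, 8.23, 9.5

# One term per sample site in scale(): "weight" first, then "measurement".
log_p_weight = normal_log_prob(weight, loc=guess, scale=1.0)
log_p_measurement = normal_log_prob(measurement, loc=weight, scale=0.75)
log_joint = log_p_weight + log_p_measurement

print(round(log_joint, 4))
```

Rounded to four decimal places, the result agrees with the `tensor(-3.0203)` returned by `scale_log_joint`.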


In [12]:
import torch
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger
import pyro
import pyro.distributions as dist

def scale_joint_distribution(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))    

def make_log_joint2(model):
    def _log_joint(cond_data, *args, **kwargs):
        with TraceMessenger() as tracer:
            with ConditionMessenger(data=cond_data):
                model(*args, **kwargs)

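        trace = tracer.trace
        logp = 0.
        # Only observed (conditioned) sample sites contribute, so this equals the
        # full log joint only if cond_data covers every sample site in the model.
        for name, node in trace.nodes.items():
            if node["type"] == "sample" and node["is_observed"]:
                # Conditioning must place the supplied value in the trace unchanged.
                assert node["value"] is cond_data[name]
                logp = logp + node["fn"].log_prob(node["value"]).sum()
        return logp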
    return _log_joint
    

scale_log_joint = make_log_joint2(scale_joint_distribution)

In [13]:
cond_data = {
    "measurement": torch.tensor(9.5),
    "weight": torch.tensor(8.23),
    }
guess = 8.5
scale_log_joint(cond_data, guess)

tensor(-3.0203)

> Note that `poutine.trace` produces a data structure (a Trace) containing a dictionary whose keys are `sample` site names and values are dictionaries containing the distribution ("`fn`") and output ("`value`") at each site, and that the output values at each site are exactly the values specified in `data`.

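This is how observations are "set" in the model: `poutine.condition` fixes the value at each sample site named in `data` and marks that site as observed (`is_observed`).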

`Messenger`s are stateful context manager objects that are placed on a global stack and send messages up and down that stack at each effectful operation (e.g., `pyro.sample` call).
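
The stack mechanism described here can be sketched in plain Python. This is an illustration under stated assumptions, not Pyro's implementation: `HANDLER_STACK`, `Messenger`, `Trace`, `Condition`, and `sample` are all hypothetical names, and `random.gauss` stands in for a distribution object.

```python
import random
from collections import OrderedDict

HANDLER_STACK = []  # global stack; the innermost active handler is last


class Messenger:
    """Base handler: entering pushes it onto the stack, exiting pops it."""

    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        HANDLER_STACK.pop()
        return False

    def process_message(self, msg):  # runs before the value is computed
        pass

    def postprocess_message(self, msg):  # runs after the value is computed
        pass


class Trace(Messenger):
    """Records every sample site it sees, keyed by site name."""

    def __init__(self):
        self.nodes = OrderedDict()

    def postprocess_message(self, msg):
        self.nodes[msg["name"]] = msg.copy()


class Condition(Messenger):
    """Fixes the value of any site named in `data` and marks it as observed."""

    def __init__(self, data):
        self.data = data

    def process_message(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            msg["is_observed"] = True


def sample(name, sampler):
    msg = {"type": "sample", "name": name, "fn": sampler,
           "value": None, "is_observed": False}
    for handler in reversed(HANDLER_STACK):  # innermost handler first
        handler.process_message(msg)
    if msg["value"] is None:  # no handler supplied a value, so draw one
        msg["value"] = sampler()
    for handler in HANDLER_STACK:  # outermost handler first
        handler.postprocess_message(msg)
    return msg["value"]


def scale(guess):
    weight = sample("weight", lambda: random.gauss(guess, 1.0))
    return sample("measurement", lambda: random.gauss(weight, 0.75))


with Trace() as tracer:
    with Condition({"measurement": 9.5}):
        scale(8.5)

for name, node in tracer.nodes.items():
    print(name, node["value"], node["is_observed"])
```

In this sketch the `Condition` handler sees the message first and fills in the value for `measurement`, so no sample is drawn there, while `weight` is sampled as usual; the outer `Trace` then records both sites.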