## Poutine: A Guide to Programming with Effect Handlers

In [4]:
!pip install pyro-ppl

Collecting pyro-ppl
[?25l  Downloading https://files.pythonhosted.org/packages/8b/0e/0523cb040c8f3ee8644b4280f6a72ed598ac7864680b667d6052fb5d445a/pyro-ppl-0.3.4.tar.gz (262kB)
[K     |████████████████████████████████| 266kB 2.8MB/s 
Collecting opt_einsum>=2.3.2 (from pyro-ppl)
[?25l  Downloading https://files.pythonhosted.org/packages/f6/d6/44792ec668bcda7d91913c75237314e688f70415ab2acd7172c845f0b24f/opt_einsum-2.3.2.tar.gz (59kB)
[K     |████████████████████████████████| 61kB 20.7MB/s 
Collecting tqdm>=4.31 (from pyro-ppl)
[?25l  Downloading https://files.pythonhosted.org/packages/9f/3d/7a6b68b631d2ab54975f3a4863f3c4e9b26445353264ef01f465dc9b0208/tqdm-4.32.2-py2.py3-none-any.whl (50kB)
[K     |████████████████████████████████| 51kB 15.0MB/s 
[?25hBuilding wheels for collected packages: pyro-ppl, opt-einsum
  Building wheel for pyro-ppl (setup.py) ... [?25l[?25hdone
  Stored in directory: /root/.cache/pip/wheels/d4/de/b5/88300d2adc973a7ec963b339d2935d34a0cf02c08b613a8a5e
  Bui

In [0]:
import torch

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

from pyro.poutine.runtime import effectful

pyro.set_rng_seed(101)

### A first example for joint probability distribution inference

- The model below defines a joint probability distribution over "weight" and "measurement":
  - weight | guess ~ Normal(guess, 1)
  - measurement | guess, weight ~ Normal(weight, 0.75)



In [0]:
def scale(guess):
    """Two-site model: weight ~ Normal(guess, 1), measurement ~ Normal(weight, 0.75).

    :param guess: prior mean for the weight.
    :returns: the value of the "measurement" sample site.
    """
    weight_dist = dist.Normal(guess, 1.0)
    weight = pyro.sample("weight", weight_dist)
    measurement_dist = dist.Normal(weight, 0.75)
    return pyro.sample("measurement", measurement_dist)

- If we had access to the inputs and outputs of each `pyro.sample` site, we could compute their log-joint:
```python
logp = dist.Normal(guess, 1.0).log_prob(weight).sum() + dist.Normal(weight, 0.75).log_prob(measurement).sum()
```



### A first look at Poutine: Pyro’s library of algorithmic building blocks

- Poutine is Pyro's library of effect handlers.
- We first compose two existing effect handlers:
  - `poutine.condition`: sets the output values of `pyro.sample` statements
  - `poutine.trace`: records the inputs, distributions, and outputs of `pyro.sample` statements

- The `ConditionMessenger` class is shown below for reference (it is not meant to be run).


In [0]:
"""
Do NOT run this chunk - just for reference
"""

# Reference copy of a Messenger that adds values at observe sites, conditioning on
# data and overriding sampling. It relies on Messenger and Trace, which are not
# imported above, so this cell is for reading only.
class ConditionMessenger(Messenger):
    """Turn sample sites named in ``data`` into observed sites.

    :param data: a dict mapping site names to values, or a Trace whose
        ``nodes[name]["value"]`` supplies the value for each site.
    """

    def __init__(self, data):
        super(ConditionMessenger, self).__init__()
        self.data = data

    def _pyro_sample(self, msg):
        """Pin ``msg["value"]`` for conditioned sites; leave all other sites alone.

        :param msg: the message for the current sample site.
        """
        site_name = msg["name"]
        if site_name not in self.data:
            # Not conditioned: no additional effect, default sampling applies.
            return None
        assert not msg["is_observed"], \
            "should not change values of existing observes"
        if isinstance(self.data, Trace):
            observed_value = self.data.nodes[site_name]["value"]
        else:
            observed_value = self.data[site_name]
        msg["value"] = observed_value
        # Marking the site observed tells other Messengers not to overwrite the value.
        msg["is_observed"] = True
        return None

    def _pyro_param(self, msg):
        # Param sites pass through unchanged.
        return None

In [7]:
def make_log_joint(model):
    """Build a function that scores ``model`` at fixed site values.

    The returned ``_log_joint(cond_data, *args, **kwargs)`` pins the sites named
    in ``cond_data`` with ``poutine.condition``, records one run of the model
    with ``poutine.trace``, and returns the trace's summed log-probability.
    """
    def _log_joint(cond_data, *args, **kwargs):
        traced_model = poutine.trace(poutine.condition(model, data=cond_data))
        return traced_model.get_trace(*args, **kwargs).log_prob_sum()
    return _log_joint


scale_log_joint = make_log_joint(scale)
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5))

tensor(-3.0203)


In [8]:
# Further explanations for chunk above

# poutine.trace and poutine.condition are wrappers around Messenger context managers
# (the TraceMessenger and ConditionMessenger used directly below); these communicate
# with the model by intercepting the messages sent by each pyro.sample statement.

# poutine.trace produces a data structure (a Trace) containing a dictionary
# whose keys are sample site names and whose values are dictionaries
# containing the distribution ("fn") and the output ("value") at each site.

# As a reminder:
# a Messenger is placed at the bottom of the stack when its __enter__ method is called,
# i.e. when it is used in a "with" statement.


from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger

def make_log_joint_2(model):
    """Same as ``make_log_joint``, written directly with Messenger context managers.

    :param model: a stochastic function containing ``pyro.sample`` statements.
    :returns: ``_log_joint(cond_data, *args, **kwargs)``. It runs ``model`` once
        with the sites named in ``cond_data`` pinned, and returns the sum of
        ``log_prob`` over every sample site in the recorded trace. Sites not in
        ``cond_data`` are sampled.
    """
    def _log_joint(cond_data, *args, **kwargs):
        # ConditionMessenger is the innermost handler, so the trace records the
        # conditioned values.
        with TraceMessenger() as tracer:
            with ConditionMessenger(data=cond_data):
                model(*args, **kwargs)

        trace = tracer.trace
        logp = 0.
        # The trace records every message regardless of type, so keep sample sites only.
        for name, node in trace.nodes.items():
            if node["type"] != "sample":
                continue
            # Only check sites conditioned here. A site observed via ``obs=`` inside
            # the model is also marked observed but has no entry in cond_data.
            if node["is_observed"] and name in cond_data:
                assert node["value"] is cond_data[name]
            logp = logp + node["fn"].log_prob(node["value"]).sum()
        return logp
    return _log_joint

scale_log_joint = make_log_joint_2(scale)

# Argument roles: the dict is handed to ConditionMessenger, and 8.5 is `guess` for `scale`.
# Sites named in the dict are pinned to the given values and scored there.
# Sites left out of the dict are sampled afresh and scored at the sampled values,
# so the first two results below depend on the RNG state.
for partially_observed in ({}, {"measurement": 9.5}):
    print(scale_log_joint(partially_observed, 8.5))
scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5)

tensor(-2.8493)
tensor(-3.1514)


tensor(-3.0203)

### Implementing new effect handlers with the Messenger API - a more complicated user-defined Messenger

In [9]:
# __init__ vs __call__ in Python:
#   __init__ runs once, when the instance is created;
#   __call__ runs each time the instance itself is called like a function.
class A:
    def __init__(self):
        print("init")

    def __call__(self):
        print("call")


a = A()  # prints "init"
a()      # prints "call"

init
call


### A tip on using clone() in the LogJointMessenger below

- `tensor.detach()` returns a new tensor that shares storage with `tensor` but does not require grad.
- `tensor.clone()` creates a copy of `tensor` that keeps the original tensor's `requires_grad` setting.
- Use `detach()` to remove a tensor from a computation graph. Use `clone()` to copy a tensor while keeping the copy part of the computation graph it came from.
- `tensor.data` returns a new tensor that shares storage with `tensor`. However, it always has `requires_grad=False` (even if the original tensor had `requires_grad=True`).

In [10]:
class LogJointMessenger(poutine.messenger.Messenger):
    """Condition every sample site on ``cond_data`` and accumulate the log-joint.

    Every ``pyro.sample`` site reached while this messenger is active must have
    an entry in ``cond_data``. Its value is pinned to that entry, and its scaled
    ``log_prob`` is added to ``self.logp``.

    :param cond_data: dict mapping sample-site names to observed values.
    """

    def __init__(self, cond_data):
        # Initialise base-class state, as the other Messenger subclasses in this
        # notebook do.
        super(LogJointMessenger, self).__init__()
        self.data = cond_data

    # __call__ is syntactic sugar for using Messengers as higher-order functions.
    # Messenger already defines __call__. It is redefined here so that the wrapped
    # function returns the accumulated log-joint instead of the model's output.
    def __call__(self, fn):
        def _fn(*args, **kwargs):
            # Entering ``self`` pushes this messenger onto Pyro's handler stack,
            # so it handles every message emitted while ``fn`` runs.
            with self:
                fn(*args, **kwargs)
                # Return inside the with-block, before __exit__ resets self.logp.
                return self.logp.clone()
        return _fn

    # __enter__/__exit__ are overridden here to reset the running total at the
    # start and end of every use.
    def __enter__(self):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__enter__(),
        # which pushes the Messenger onto the bottom of the stack.
        return super(LogJointMessenger, self).__enter__()

    # __exit__ takes the same arguments in all Python context managers.
    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__exit__ method.
        return super(LogJointMessenger, self).__exit__(exc_type, exc_value, traceback)

    # _pyro_sample is called once per pyro.sample site. Messenger._process_message
    # dispatches messages whose "type" is "sample" to this method. msg contains the
    # site name, distribution, value and other metadata.
    def _pyro_sample(self, msg):
        assert msg["name"] in self.data
        msg["value"] = self.data[msg["name"]]
        # Since we've observed a value for this site, set "is_observed" to True.
        # This tells other Messengers not to overwrite msg["value"] with a sample.
        msg["is_observed"] = True
        # msg["scale"] is a multiplicative factor applied to each site's log_prob.
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()

        
# As a context manager: LogJointMessenger joins the handler stack and processes
# every message generated while `scale` runs. Read m.logp inside the block,
# because __exit__ resets it.
with LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp.clone())

# As a higher-order function: __call__ wraps `scale` so it returns the log-joint.
observed = {"measurement": 9.5, "weight": 8.23}
scale_log_joint = LogJointMessenger(cond_data=observed)(scale)
print(scale_log_joint(8.5))

tensor(-3.0203)
tensor(-3.0203)


In [11]:
def log_joint(model=None, cond_data=None):
    """Return a LogJointMessenger for ``cond_data``, optionally applied to ``model``.

    With a model, the result is a callable that returns the model's log-joint.
    Without one, the messenger itself is returned, for use in a ``with`` block.
    """
    messenger = LogJointMessenger(cond_data=cond_data)
    if model is None:
        return messenger
    return messenger(model)


# Every sample site reached in `scale` must have an entry in cond_data.
scale_log_joint = log_joint(scale, cond_data={"measurement": 9.5, "weight": 8.23})
print(scale_log_joint(8.5))

tensor(-3.0203)


### Extension to the LogJointMessenger example


In [12]:
class LogJointMessenger2(poutine.messenger.Messenger):
    """Like LogJointMessenger, but tolerates sample sites missing from ``cond_data``.

    Sites named in ``cond_data`` are pinned to the given value. Other sites are
    left to the remaining handlers. The log-prob of every sample site is
    accumulated in ``_pyro_post_sample``, once the value is final.

    :param cond_data: dict mapping some sample-site names to values.
    """

    def __init__(self, cond_data):
        # Initialise base-class state, as the other Messenger subclasses in this
        # notebook do.
        super(LogJointMessenger2, self).__init__()
        self.data = cond_data

    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                # Return inside the with-block, before __exit__ resets self.logp.
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        return super(LogJointMessenger2, self).__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        return super(LogJointMessenger2, self).__exit__(exc_type, exc_value, traceback)

    def _pyro_sample(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            # "done" marks the value as final.
            msg["done"] = True

    # Post-processing is necessary because some effects can only be applied
    # after all other effect handlers have had a chance to update the message once.
    def _pyro_post_sample(self, msg):
        # The "done" flag asserts that no more modifications to value and fn will be performed.
        assert msg["done"]
        # Printed to show which sites reach the post-processing hook, and in which order.
        print(msg["name"])
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()
        
# First run: only "measurement" is conditioned. "weight" is not in cond_data, so
# LogJointMessenger2 leaves it alone. It still reaches _pyro_post_sample with
# msg["done"] set, and its log-prob is included. Presumably Pyro's default
# processing, which runs after all handlers, draws the sample and sets the flag.
# Second run: both sites are conditioned.
for site_values in ({"measurement": 9.5}, {"measurement": 9.5, "weight": 8.23}):
    with LogJointMessenger2(cond_data=site_values) as m:
        scale(8.5)
        print(m.logp)

weight
measurement
tensor(-1.8835)
weight
measurement
tensor(-3.0203)


### Inside the messages sent by Messengers

In [13]:
# Anatomy of a message. This dict is built by hand for illustration only and is
# not passed to Pyro; the pyro.sample call at the end has the matching name,
# distribution, infer and obs arguments.
msg = dict(
    # Name, inputs, function and output of the site.
    # These are generally the only fields you'll need to think about.
    name="x",
    fn=dist.Bernoulli(0.5),
    value=None,         # eventually holds the value returned by pyro.sample
    is_observed=False,  # False because obs=None by default; only used by sample sites
    args=(),            # positional arguments passed to "fn"; usually empty for sample sites
    kwargs={},          # keyword arguments passed to "fn"; usually empty for sample sites
    # Metadata needed or stored by a particular inference algorithm.
    infer={"enumerate": "parallel"},
    # The remaining fields are mostly used by Pyro's internals or by more advanced effects.
    type="sample",      # label used by Messenger._process_message to dispatch, here to _pyro_sample
    done=False,
    stop=False,
    scale=torch.tensor(1.),  # multiplicative scale factor applied to the site's log_prob
    mask=None,
    continuation=None,
    cond_indep_stack=(),     # will hold metadata from each pyro.plate enclosing the site
)
pyro.sample("x", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}, obs=None)

tensor(1.)

### Implementing inference algorithms with existing effect handlers: examples

### Example 1: Variational inference with a Monte Carlo ELBO

- A complete ELBO training example is given on the Mini-Pyro page.

In [0]:
def monte_carlo_elbo(model, guide, batch, *args, **kwargs):
    """Single-sample Monte Carlo estimate of the negative ELBO.

    :param model: stochastic function whose observed sites are named in ``batch``.
    :param guide: stochastic function proposing values for the model's latent sites.
    :param batch: dict mapping observed site names to data (fixed with poutine.condition).
    :returns: ``-(sum log p - sum log q)``, evaluated at one draw from the guide.
    """
    # Run the guide forward once, unmodified, recording values and distributions.
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)

    # Fix observed sites to the batch, replay the guide's latent values, and record
    # the distribution the model uses at every sample site.
    replayed_model = poutine.replay(poutine.condition(model, data=batch), trace=guide_trace)
    model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)

    elbo = 0.
    for site_name, site in model_trace.nodes.items():
        if site["type"] != "sample":
            continue
        # log p term for every sample site, observed or latent.
        elbo = elbo + site["fn"].log_prob(site["value"]).sum()
        if not site["is_observed"]:
            # log q term for latent sites, scored under the guide's distribution.
            elbo = elbo - guide_trace.nodes[site_name]["fn"].log_prob(site["value"]).sum()
    return -elbo

  
# Use poutine.trace and poutine.block to record pyro.param calls for optimization.
def train(model, guide, data):
    """Minimise ``monte_carlo_elbo`` over ``data`` with Adam.

    :param data: iterable of batches. Each batch is a dict of observed values
        passed to ``monte_carlo_elbo``.
    """
    optimizer = pyro.optim.Adam({})
    for batch in data:
        # This poutine.trace records every parameter that appears in the model and
        # guide during the execution of monte_carlo_elbo.
        with poutine.trace() as param_capture:
            # poutine.block hides every non-param message, so only params appear
            # in the trace above.
            with poutine.block(hide_fn=lambda node: node["type"] != "param"):
                loss = monte_carlo_elbo(model, guide, batch)

        loss.backward()
        params = set(node["value"].unconstrained()
                     for node in param_capture.trace.nodes.values())
        # A PyroOptim takes an optimisation step when called with the parameters;
        # that is its documented interface, rather than a torch-style .step().
        optimizer(params)
        pyro.infer.util.zero_grads(params)

### Example 2: Exact inference via sequential enumeration

- This example uses poutine.queue, itself implemented using poutine.trace, poutine.replay, and poutine.block, to enumerate over possible values of all discrete variables in a model and compute a marginal distribution over all possible return values or the possible values at a particular sample site.

In [0]:
# Messenger that does a nonlocal exit by raising a NonlocalExit exception
# (reference re-implementation of Pyro's EscapeMessenger).
class EscapeMessenger(poutine.messenger.Messenger):
    """Arrange a nonlocal exit at sample sites for which ``escape_fn(msg)`` is true.

    :param escape_fn: function that takes a msg as input and returns True if the
        poutine should perform a nonlocal exit at that site.
    """

    def __init__(self, escape_fn):
        super(EscapeMessenger, self).__init__()
        self.escape_fn = escape_fn

    def _pyro_sample(self, msg):
        # Evaluate self.escape_fn on the site. If it returns True, mark the message
        # done and stopped, and install a continuation that raises NonlocalExit(msg).
        # Otherwise fall through to the default _pyro_sample behaviour.
        if self.escape_fn(msg):
            msg["done"] = True
            msg["stop"] = True

            def cont(m):
                # Imported lazily: it is only needed when an escape actually happens.
                from pyro.poutine.runtime import NonlocalExit
                raise NonlocalExit(m)
            msg["continuation"] = cont
        return None

# TODO: when and how Pyro invokes msg["continuation"] remains to be checked.

In [18]:
# Some preparation: a quick tour of the standard-library FIFO queue.
# The module is imported under an alias for two reasons: `queue` is not imported
# elsewhere, and this cell later defines a function named `queue`, which would
# shadow the module when the cell is re-run.
import queue as std_queue

# Initialize a queue that holds at most 10 items
temp = std_queue.Queue(10)
# Insert an element
temp.put(2)
# Get and remove the element
temp.get()

# A pure-Python sketch of functools.partial: freeze some positional and keyword
# arguments now, then merge in any extra ones at call time. Call-time keywords
# override the frozen ones.
def partial(func, *args, **keywords):
    def call_with_frozen(*more_args, **more_keywords):
        merged = dict(keywords)
        merged.update(more_keywords)
        return func(*args, *more_args, **merged)
    call_with_frozen.func, call_with_frozen.args, call_with_frozen.keywords = func, args, keywords
    return call_with_frozen

# poutine.queue: sequential enumeration over discrete variables
# (reference re-implementation of pyro.poutine.queue).
# Given a stochastic function and a queue of partial traces,
# return the return value of a complete trace from the queue.
def queue(fn=None, queue=None, max_tries=None,
          extend_fn=None, escape_fn=None, num_samples=None):
    """
    :param fn: a stochastic function (callable containing Pyro primitive calls)
    :param queue: a queue data structure like multiprocessing.Queue to hold partial traces
    :param max_tries: maximum number of attempts to compute a single complete trace
    :param extend_fn: function (possibly stochastic) that takes a partial trace and a site,
        and returns a list of extended traces
    :param escape_fn: function (possibly stochastic) that takes a partial trace and a site,
        and returns a boolean value to decide whether to exit
    :param num_samples: optional number of extended traces for extend_fn to return
    :returns: stochastic function decorated with poutine logic
    :raises ValueError: (from the decorated function) if no complete trace is
        found within ``max_tries`` attempts
    """
    # None of these names are imported at notebook level. Bring them into scope
    # here; otherwise calling the decorated function raises NameError.
    import functools
    from pyro.poutine import escape, replay, trace, util
    from pyro.poutine.runtime import NonlocalExit

    if max_tries is None:
        max_tries = int(1e6)
    if extend_fn is None:
        extend_fn = util.enum_extend
    if escape_fn is None:
        escape_fn = util.discrete_escape
    if num_samples is None:
        num_samples = -1

    def wrapper(wrapped):
        def _fn(*args, **kwargs):
            for _ in range(max_tries):
                assert not queue.empty(), \
                    "trying to get() from an empty queue will deadlock"
                next_trace = queue.get()
                try:
                    # Replay the partial trace, escape at the next unhandled site,
                    # and record everything visited on the way.
                    ftr = trace(escape(replay(wrapped, trace=next_trace),
                                       escape_fn=functools.partial(escape_fn, next_trace)))
                    return ftr(*args, **kwargs)
                except NonlocalExit as site_container:
                    site_container.reset_stack()
                    # Push every extension of the partial trace at the escaped site.
                    for tr in extend_fn(ftr.trace.copy(), site_container.site,
                                        num_samples=num_samples):
                        queue.put(tr)

            raise ValueError("max tries ({}) exceeded".format(str(max_tries)))
        return _fn
    return wrapper(fn) if fn is not None else wrapper

2

In [0]:
def sequential_discrete_marginal(model, data, site_name="_RETURN"):
    """Exact marginal over ``site_name`` via enumeration of all discrete sites.

    :param model: stochastic function called with no arguments; its discrete
        sample sites are enumerated by ``poutine.queue``.
    :param data: dict mapping observed site names to values (fixed with
        ``poutine.condition``).
    :param site_name: trace node whose values form the marginal. The default
        ``"_RETURN"`` is presumably the node holding the model's return value.
    :returns: ``dist.Empirical`` over the enumerated values, weighted by the
        normalised log-joint of each complete trace.
    """
    from queue import Queue  # Python 3 stdlib; replaces six.moves.queue

    q = Queue()  # first-in first-out queue of partial traces
    q.put(poutine.Trace())  # seed the queue with an empty trace

    # As before, fix the values of observed random variables with poutine.condition
    # (data is assumed to be a dict keyed by sample-site names in model).
    conditioned_model = poutine.condition(model, data=data)

    # poutine.queue repeatedly pushes and pops partially completed executions,
    # performing breadth-first enumeration over all discrete sample sites.
    enum_model = poutine.queue(conditioned_model, queue=q)

    # Run the enumeration by repeatedly tracing enum_model. Collect the value at
    # site_name and the log-joint of each complete trace.
    samples, log_weights = [], []
    while not q.empty():
        trace = poutine.trace(enum_model).get_trace()
        # as_tensor lets plain Python values (e.g. int return values) be stacked too.
        samples.append(torch.as_tensor(trace.nodes[site_name]["value"]))
        log_weights.append(trace.log_prob_sum())

    # Turn the samples and log-joints into a normalised histogram.
    samples = torch.stack(samples, 0)
    log_weights = torch.stack(log_weights, 0)
    log_weights = log_weights - torch.logsumexp(log_weights, dim=0)
    # Empirical distribution over the enumerated values.
    return dist.Empirical(samples, log_weights)

### Example 3: Implementing lazy evaluation with the Messenger API

In [20]:
class Foo:
    a = 5


fooInstance = Foo()
# One check per classinfo: a single class, a tuple without Foo, a tuple with Foo.
for classinfo in (Foo, (list, tuple), (list, tuple, Foo)):
    print(isinstance(fooInstance, classinfo))

# isinstance(object, classinfo)
# object - object to be checked
# classinfo - class, type, or tuple of classes and types
# Returns True if the object is an instance of the class (or of one of its
# subclasses), or of any element of the tuple.


True
False
True


In [0]:
# First define a LazyValue class that we will use to build up a computation graph.

# With LazyValue, implementing lazy evaluation as a Messenger compatible with other
# effect handlers is surprisingly easy: we just make each msg["value"] a LazyValue
# and introduce a new operation type "apply" for deterministic operations.
class LazyValue(object):
    """A deferred call ``fn(*args, **kwargs)``; fn and its inputs may be LazyValues.

    Nothing is computed until :meth:`evaluate` is called. The result is then
    cached, so each node runs at most once (this also holds for a None result).
    """

    def __init__(self, fn, *args, **kwargs):
        self._expr = (fn, args, kwargs)
        self._value = None
        # Separate flag so that a result of None is cached as well.
        self._evaluated = False

    def __str__(self):
        """Render as ``(fn arg1 arg2 ... key=value ...)``."""
        fn, args, kwargs = self._expr
        tokens = [str(arg) for arg in args]
        tokens += ["{}={}".format(key, value) for key, value in kwargs.items()]
        return "({} {})".format(str(fn), " ".join(tokens))

    def evaluate(self):
        """Recursively evaluate ``fn`` and its arguments, then cache and return the result."""
        if not self._evaluated:
            fn, args, kwargs = self._expr

            fn = fn.evaluate() if isinstance(fn, LazyValue) else fn
            args = tuple(arg.evaluate() if isinstance(arg, LazyValue) else arg for arg in args)
            kwargs = {k: v.evaluate() if isinstance(v, LazyValue) else v for k, v in kwargs.items()}

            self._value = fn(*args, **kwargs)
            self._evaluated = True
        return self._value

class LazyMessenger(pyro.poutine.messenger.Messenger):
    """Replace the value of every unhandled "apply"/"sample" message with a LazyValue.

    Overriding ``_process_message`` directly, instead of defining ``_pyro_apply``
    and ``_pyro_sample``, lets one method handle both message types.
    """

    def _process_message(self, msg):
        if msg["type"] not in ("apply", "sample") or msg["done"]:
            return
        # Leave observed sample sites alone (e.g. values set by poutine.condition):
        # they are data, not something to sample lazily.
        if msg["type"] == "sample" and msg.get("is_observed", False):
            return
        msg["done"] = True
        msg["value"] = LazyValue(msg["fn"], *msg["args"], **msg["kwargs"])

In [0]:
# Deterministic operations registered as effectful "apply" sites, so that
# messengers such as LazyMessenger can intercept each call.
@effectful(type="apply")
def add(x, y):
    """Sum of ``x`` and ``y``."""
    total = x + y
    return total


@effectful(type="apply")
def mul(x, y):
    """Product of ``x`` and ``y``."""
    product = x * y
    return product


@effectful(type="apply")
def sigmoid(x):
    """Element-wise logistic sigmoid of tensor ``x``."""
    squashed = torch.sigmoid(x)
    return squashed


@effectful(type="apply")
def normal(loc, scale):
    """Build a ``dist.Normal`` with the given location and scale."""
    return dist.Normal(loc=loc, scale=scale)

In [23]:
# Applied to another model: every arithmetic step goes through the effectful ops
# above, so LazyMessenger can defer all of them.
def biased_scale(guess):
    weight = pyro.sample("weight", normal(guess, 1.))
    tolerance = pyro.sample("tolerance", normal(0., 0.25))
    # measurement ~ Normal(0.8 * weight + 1, sigmoid(tolerance))
    loc = add(mul(weight, 0.8), 1.)
    spread = sigmoid(tolerance)
    return pyro.sample("measurement", normal(loc, spread))


with LazyMessenger():
    v = biased_scale(8.5)
    print(v)             # the unevaluated expression graph
    print(v.evaluate())  # forces evaluation of the whole graph

((<function normal at 0x7fca60dac8c8> (<function add at 0x7fca60dacf28> (<function mul at 0x7fca60dace18> ((<function normal at 0x7fca60dac8c8> 8.5 1.0) ) 0.8) 1.0) (<function sigmoid at 0x7fca60dacd08> ((<function normal at 0x7fca60dac8c8> 0.0 0.25) ))) )
tensor(8.7122)
