## Poutine: A Guide to Programming with Effect Handlers

In [2]:
!pip install pyro-ppl==0.3.4

In [0]:
import torch

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

from pyro.poutine.runtime import effectful

pyro.set_rng_seed(101)

### A first example for joint probability distribution inference

- The model below defines a joint probability distribution over "weight" and "measurement":
  - weight | guess ~ Normal(guess, 1)
  - measurement | guess, weight ~ Normal(weight, 0.75)



In [0]:
def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

- If we had access to the inputs and outputs of each pyro.sample site, we could compute their log-joint:
```
logp = dist.Normal(guess, 1.0).log_prob(weight).sum() + dist.Normal(weight, 0.75).log_prob(measurement).sum()
```
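As a library-free check of the formula above, the same log-joint can be evaluated with only the Python standard library. This is a minimal sketch: `normal_log_prob` is a hypothetical helper written out from the Normal density, and the values (guess=8.5, weight=8.23, measurement=9.5) are the ones used in the cells below.

```python
import math

def normal_log_prob(x, loc, scale):
    """Log-density of Normal(loc, scale) at x (hypothetical helper)."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

guess, weight, measurement = 8.5, 8.23, 9.5

# log p(weight | guess) + log p(measurement | weight), one term per sample site
logp = normal_log_prob(weight, guess, 1.0) + normal_log_prob(measurement, weight, 0.75)
print(round(logp, 4))  # -3.0203
```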



### A first look at Poutine: Pyro’s library of algorithmic building blocks

- Poutine is Pyro's library of effect handlers.
- We start by composing two existing effect handlers:
  - poutine.condition: sets the output values of pyro.sample statements
  - poutine.trace: records the inputs, distributions, and outputs of pyro.sample statements

- The ConditionMessenger class behind poutine.condition is reproduced below for reference.


In [0]:
"""
Do NOT run this chunk - just for reference
"""

# Adds values at observe sites to condition on data and override sampling
class ConditionMessenger(Messenger):
    def __init__(self, data):
        # data here would be a dictionary or a Trace
        super(ConditionMessenger, self).__init__()
        self.data = data

    def _pyro_sample(self, msg):
        # msg here would be current message at a trace site
        # returns a sample from the stochastic function at the site

        # If msg["name"] appears in self.data, 
        # convert the sample site into an observe site
        # whose observed value is the value from self.data[msg["name"]].
        # Otherwise, implements default sampling behavior
        # with no additional effects.

        name = msg["name"]
        if name in self.data:
            assert not msg["is_observed"], \
                "should not change values of existing observes"
            if isinstance(self.data, Trace):
                msg["value"] = self.data.nodes[name]["value"]
            else:
                msg["value"] = self.data[name]
            msg["is_observed"] = True
        return None
      
    def _pyro_param(self, msg):
        return None

In [5]:
def make_log_joint(model):
    def _log_joint(cond_data, *args, **kwargs):
        conditioned_model = poutine.condition(model, data=cond_data)
        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
        return trace.log_prob_sum()
    return _log_joint

scale_log_joint = make_log_joint(scale)
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5))

tensor(-3.0203)


In [6]:
# Further explanations for chunk above

# poutine.trace and poutine.condition are thin wrappers around Messenger context managers
# (TraceMessenger and ConditionMessenger); they communicate with the model through
# the effect-handler stack that every pyro.sample call consults

# poutine.trace produces a data structure (a Trace) containing a dictionary
# whose keys are sample site names and values are dictionaries
# containing the distribution ("fn") and output ("value") at each site

# As a reminder:
# a Messenger is placed at the bottom of the stack when its __enter__ method is called,
# i.e. when it is used in a "with" statement


from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger

def make_log_joint_2(model):
    def _log_joint(cond_data, *args, **kwargs):
        with TraceMessenger() as tracer:
            with ConditionMessenger(data=cond_data):
                # here sample "weight" and "measurement" in model "scale"
                model(*args, **kwargs)

        trace = tracer.trace
        logp = 0.
        # the trace records every site (sample, param, ...), so keep only sample sites
        for name, node in trace.nodes.items():
            if node["type"] == "sample":
                if node["is_observed"]:
                    assert node["value"] is cond_data[name]
                logp = logp + node["fn"].log_prob(node["value"]).sum()
        return logp
    return _log_joint

scale_log_joint = make_log_joint_2(scale)

# 8.5 below is the value of the argument "guess" of model "scale";
# the dictionary is passed to ConditionMessenger as cond_data

# Sites named in cond_data use the given values when computing log-probs.
# Sites missing from cond_data are sampled from their distributions,
# so the first two results below depend on the sampled values
print(scale_log_joint({}, 8.5))
print(scale_log_joint({"measurement": 9.5}, 8.5))
scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5)

tensor(-2.8493)
tensor(-3.1514)


tensor(-3.0203)

### Implementing new effect handlers with the Messenger API: a more complicated user-defined Messenger

In [7]:
# __init__ and __call__ in Python
class A:
  def __init__(self):
    print("init")
  def __call__(self):
    print("call")
    
# __init__ runs when the instance is created
a = A()
# __call__ runs when the instance itself is called like a function
a()

init
call
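LogJointMessenger below relies on two Python protocols at once: `__enter__`/`__exit__` make an object usable in a `with` statement, and `__call__` lets an instance wrap a function. A minimal pure-Python sketch of that combination, using a hypothetical `Recorder` class that logs when it is entered and exited:

```python
class Recorder:
    """Hypothetical example: a context manager that can also wrap functions."""

    def __init__(self):
        self.events = []

    def __enter__(self):
        self.events.append("enter")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.events.append("exit")
        return False  # do not swallow exceptions

    def __call__(self, fn):
        # Calling the instance on a function returns a new function
        # that runs fn inside "with self".
        def wrapped(*args, **kwargs):
            with self:
                return fn(*args, **kwargs)
        return wrapped

rec = Recorder()
double = rec(lambda x: 2 * x)
print(double(21))   # 42
print(rec.events)   # ['enter', 'exit']
```

This is the same shape as `LogJointMessenger.__call__`: the wrapper enters the handler, runs the model, and returns a result computed while the handler is active.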


### A tip for using clone() below in the LogJointMessenger

- tensor.detach() returns a new tensor that shares storage with the original but does not require grad.
- tensor.clone() returns a copy of the tensor's data in new storage; the copy stays in the computation graph, so it requires grad whenever the original does.
- Use detach() to remove a tensor from a computation graph, and clone() to copy a tensor while keeping the copy part of the computation graph it came from.
- tensor.data also returns a tensor that shares storage with the original, but it always has requires_grad=False (even if the original tensor has requires_grad=True).
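These points can be checked directly in PyTorch. The sketch uses only standard tensor methods (`clone`, `detach`, `data_ptr`, `backward`); comparing `data_ptr()` values is just a convenient way to tell shared storage from copied storage.

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)

c = x.clone()   # new storage, still part of the autograd graph
d = x.detach()  # shares storage with x, cut from the graph
v = x.data      # also shares storage, never requires grad

print(c.requires_grad, d.requires_grad, v.requires_grad)            # True False False
print(c.data_ptr() == x.data_ptr(), d.data_ptr() == x.data_ptr())  # False True

# Gradients flow back through the clone to x ...
(3 * c).sum().backward()
print(x.grad)  # tensor([3., 3.])

# ... while an in-place change to the detached tensor shows up in x (x[0] becomes 10),
# and the clone keeps its own copy of the old values.
d[0] = 10.0
print(x, c)
```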

In [8]:
class LogJointMessenger(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    # __call__ is syntactic sugar for using Messengers as higher-order functions.
    # Messenger already defines __call__, but we re-define it here
    # for exposition and to change the return value:
    def __call__(self, fn):
        def _fn(*args, **kwargs):
            # "with self" pushes this Messenger onto the handler stack,
            # so it handles every message emitted while "fn" runs
            with self:
                fn(*args, **kwargs)
                # return self.logp
                return self.logp.clone()
        return _fn
    
    # __enter__ and __exit__ are overridden here to (re)initialize self.logp.
    # Any Messenger that overrides them must still call the base implementations.
    
    def __enter__(self):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__enter__()
        # in their __enter__ methods
        # __enter__ would push Messenger itself to the bottom of the stack
        return super(LogJointMessenger, self).__enter__()

    # __exit__ takes the same arguments in all Python context managers
    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__exit__ method
        # in their __exit__ methods.
        return super(LogJointMessenger, self).__exit__(exc_type, exc_value, traceback)

    # _pyro_sample will be called once per pyro.sample site.
    # It takes a dictionary msg containing the name, distribution,
    # observation or sample value, and other metadata from the sample site.
    # Messenger._process_message dispatches each message to _pyro_<msg["type"]>, so this runs for every sample message
    def _pyro_sample(self, msg):
        assert msg["name"] in self.data
        msg["value"] = self.data[msg["name"]]
        # Since we've observed a value for this site, we set the "is_observed" flag to True
        # This tells any other Messengers not to overwrite msg["value"] with a sample.
        msg["is_observed"] = True
        # msg["scale"] (torch.tensor(1.) by default) is a multiplicative factor
        # applied to each site's log_prob, e.g. by poutine.scale or by subsampling in pyro.plate
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()

        
# add the LogJointMessenger into the handler stack to process all messages generated during model "scale"
with LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    # print(m.logp)
    print(m.logp.clone())

scale_log_joint = LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23})(scale)
print(scale_log_joint(8.5))

tensor(-3.0203)
tensor(-3.0203)


In [9]:
# A convenience wrapper in the style of the poutine handlers: returns the wrapped model if one is given,
# otherwise the Messenger itself (usable in a "with" statement)
def log_joint(model=None, cond_data=None):
    msngr = LogJointMessenger(cond_data=cond_data)
    return msngr(model) if model is not None else msngr

# Every sample site in model "scale" must appear in cond_data; otherwise the assert in _pyro_sample fails
scale_log_joint = log_joint(scale, cond_data={"measurement": 9.5, "weight": 8.23})
print(scale_log_joint(8.5))

tensor(-3.0203)


### Extension to the LogJointMessenger example


In [10]:
class LogJointMessenger2(poutine.messenger.Messenger):
  
    def __init__(self, cond_data):
        self.data = cond_data
    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn
    def __enter__(self):
        self.logp = torch.tensor(0.)
        return super(LogJointMessenger2, self).__enter__()
    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        return super(LogJointMessenger2, self).__exit__(exc_type, exc_value, traceback)  
      
    def _pyro_sample(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            msg["done"] = True
    
    # necessary because some effects can only be applied 
    # after all other effect handlers have had a chance to update the message once
    def _pyro_post_sample(self, msg):
        assert msg["done"]  # the "done" flag asserts that no more modifications to value and fn will be performed.
        print(msg["name"])
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()
        
with LogJointMessenger2(cond_data={"measurement": 9.5}) as m:
    # "weight" is not in cond_data, so LogJointMessenger2 leaves it alone; with no other
    # Messenger on the stack, Pyro's default processing samples it and sets msg["done"] = True
    # before _pyro_post_sample runs
    scale(8.5)
    print(m.logp)

with LogJointMessenger2(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp)

weight
measurement
tensor(-1.8835)
weight
measurement
tensor(-3.0203)
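The ordering that makes `_pyro_post_sample` useful can be imitated in pure Python. In this sketch, handlers' `process` hooks run from the innermost (most recently entered) handler outward, then a default step fills in any value no handler marked `done`, and finally the `postprocess` hooks run in the opposite order. `STACK`, `Logger`, `default_process` and `apply_stack` are simplified, hypothetical stand-ins (no `stop` flag, no continuations), not Pyro's code.

```python
import random

STACK = []

class Logger:
    """Hypothetical handler that records when its hooks run."""
    def __init__(self, label, log, fix=None):
        self.label, self.log, self.fix = label, log, fix or {}
    def __enter__(self):
        STACK.append(self)
        return self
    def __exit__(self, *exc):
        STACK.pop()
    def process(self, msg):
        self.log.append(("process", self.label, msg["name"]))
        if msg["name"] in self.fix:
            msg["value"] = self.fix[msg["name"]]
            msg["done"] = True
    def postprocess(self, msg):
        # by now every message has a value and done == True
        self.log.append(("post", self.label, msg["name"], msg["done"]))

def default_process(msg):
    if not msg["done"]:
        msg["value"] = random.random()  # stand-in for sampling from msg["fn"]
        msg["done"] = True

def apply_stack(msg):
    for handler in reversed(STACK):  # innermost first
        handler.process(msg)
    default_process(msg)
    for handler in STACK:            # outermost first
        handler.postprocess(msg)
    return msg["value"]

def sample(name):
    return apply_stack({"name": name, "value": None, "done": False})

log = []
with Logger("outer", log):
    with Logger("inner", log, fix={"measurement": 9.5}):
        sample("weight")             # not fixed: filled in by default_process
        m = sample("measurement")    # fixed by the inner handler
for entry in log:
    print(entry)
```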


### Inside the messages sent by Messengers

In [11]:
msg = {
    # The following fields contain the name, inputs, function, and output of a site.
    # These are generally the only fields you'll need to think about.
    "name": "x",
    "fn": dist.Bernoulli(0.5),
    "value": None,  # msg["value"] will eventually contain the value returned by pyro.sample
    "is_observed": False,  # because obs=None by default; only used by sample sites
    
    "args": (),  # positional arguments passed to "fn" when it is called; usually empty for sample sites
    "kwargs": {},  # keyword arguments passed to "fn" when it is called; usually empty for sample sites
    
    # This field typically contains metadata needed or stored by a particular inference algorithm
    "infer": {"enumerate": "parallel"},
    
    # The remaining fields are generally only used by Pyro's internals,
    # or for implementing more advanced effects beyond the scope of this tutorial
    "type": "sample",  # label used by Messenger._process_message to dispatch, in this case to _pyro_sample
    "done": False,
    "stop": False,
    "scale": torch.tensor(1.),  # Multiplicative scale factor that can be applied to each site's log_prob
    "mask": None,
    "continuation": None,
    "cond_indep_stack": (),  # Will contain metadata from each pyro.plate enclosing this sample site.
}
pyro.sample("x", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}, obs=None)

tensor(1.)
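To connect the dictionary above with the `pyro.sample` call beneath it, here is a rough sketch of how a sample call's arguments could be packed into such a message, and how a Messenger-style class could route it to a `_pyro_<type>` method. `make_sample_msg`, `Dispatcher`, `RecordSampleSites` and `coin` are hypothetical; Pyro's own `pyro.sample` and `Messenger._process_message` do more, and the exact field list may differ between versions.

```python
import random

def make_sample_msg(name, fn, *args, obs=None, infer=None, **kwargs):
    """Hypothetical: pack pyro.sample-style arguments into a message dict."""
    msg = {
        "type": "sample", "name": name, "fn": fn,
        "args": args, "kwargs": kwargs,
        "value": None, "is_observed": False,
        "infer": infer or {}, "scale": 1.0, "mask": None,
        "cond_indep_stack": (), "done": False, "stop": False,
        "continuation": None,
    }
    if obs is not None:  # an observed site starts out with its value filled in
        msg["value"] = obs
        msg["is_observed"] = True
    return msg

class Dispatcher:
    """Hypothetical: route a message to a _pyro_<type> method."""
    def process(self, msg):
        method = getattr(self, "_pyro_" + msg["type"], None)
        return method(msg) if method is not None else None

class RecordSampleSites(Dispatcher):
    def __init__(self):
        self.seen = []
    def _pyro_sample(self, msg):
        self.seen.append((msg["name"], msg["is_observed"], msg["value"]))

def coin():
    # stand-in for dist.Bernoulli(0.5)
    return float(random.random() < 0.5)

handler = RecordSampleSites()
handler.process(make_sample_msg("x", coin, infer={"enumerate": "parallel"}))
handler.process(make_sample_msg("y", coin, obs=1.0))
handler.process({"type": "param", "name": "p"})  # no _pyro_param method: ignored
print(handler.seen)  # [('x', False, None), ('y', True, 1.0)]
```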

## Implementing inference algorithms with existing effect handlers: examples

### Example 1: Variational inference with a Monte Carlo ELBO

- A related ELBO training loop can be found on the Mini-Pyro page.

In [0]:
def monte_carlo_elbo(model, guide, batch, *args, **kwargs):
    # assuming batch is a dictionary, we use poutine.condition to fix values of observed variables
    conditioned_model = poutine.condition(model, data=batch)

    # we'll approximate the expectation in the ELBO with a single sample:
    # first, we run the guide forward unmodified and record values and distributions
    # at each sample site using poutine.trace
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)

    # we use poutine.replay to set the values of latent variables in the model
    # to the values sampled above by our guide, and use poutine.trace
    # to record the distributions that appear at each sample site in the model
    model_trace = poutine.trace(poutine.replay(conditioned_model, 
                                               trace=guide_trace)).get_trace(*args, **kwargs)
    
    elbo = 0.
    for name, node in model_trace.nodes.items():
        if node["type"] == "sample":
            elbo = elbo + node["fn"].log_prob(node["value"]).sum()
            if not node["is_observed"]:
                elbo = elbo - guide_trace.nodes[name]["fn"].log_prob(node["value"]).sum()
    return -elbo

  
# use poutine.trace and poutine.block to record pyro.param calls for optimization
def train(model, guide, data):
    optimizer = pyro.optim.Adam({})
    for batch in data:
        # this poutine.trace will record all of the parameters that appear in the model and guide
        # during the execution of monte_carlo_elbo
        with poutine.trace() as param_capture:
            # we use poutine.block here so that only parameters appear in the trace above
            with poutine.block(hide_fn=lambda node: node["type"] != "param"):
                loss = monte_carlo_elbo(model, guide, batch)

        loss.backward()
        params = set(node["value"].unconstrained()
                     for node in param_capture.trace.nodes.values())
        optimizer.step(params)
        pyro.infer.util.zero_grads(params)
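The single-sample estimator in `monte_carlo_elbo` can be illustrated without Pyro on a conjugate model whose answer is known. Assume, for this sketch only, z ~ Normal(0, 1) and x | z ~ Normal(z, 1). Then the exact posterior is Normal(x/2, sqrt(1/2)) and the evidence is p(x) = Normal(x; 0, sqrt(2)). When the guide equals the exact posterior, log p(x, z) - log q(z) = log p(x) for every z, so even a single-sample ELBO has zero variance. The helper names are hypothetical; the structure mirrors `monte_carlo_elbo`: draw z from the guide, then score it under the model with x fixed.

```python
import math
import random

def normal_log_prob(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

def single_sample_elbo_loss(x, guide_loc, guide_scale, rng):
    # "trace" the guide: draw z and record its log-density under q
    z = rng.gauss(guide_loc, guide_scale)
    log_q = normal_log_prob(z, guide_loc, guide_scale)
    # "replay" z in the model conditioned on x: log p(z) + log p(x | z)
    log_p = normal_log_prob(z, 0.0, 1.0) + normal_log_prob(x, z, 1.0)
    return -(log_p - log_q)  # negative ELBO, as returned by monte_carlo_elbo

x = 1.3
rng = random.Random(0)
exact_log_evidence = normal_log_prob(x, 0.0, math.sqrt(2.0))

# Guide equal to the exact posterior: every estimate equals -log p(x)
exact = [single_sample_elbo_loss(x, x / 2, math.sqrt(0.5), rng) for _ in range(5)]
# A mismatched guide gives a noisy loss whose mean exceeds -log p(x) by KL(q || posterior)
loose = [single_sample_elbo_loss(x, 0.0, 1.0, rng) for _ in range(20000)]

print(-exact_log_evidence, exact)
print(sum(loose) / len(loose) > -exact_log_evidence)  # True: the ELBO is a lower bound
```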

### Example 2: Exact inference via sequential enumeration

- This example uses poutine.queue, itself implemented using poutine.trace, poutine.replay, and poutine.block, to enumerate over possible values of all discrete variables in a model and compute a marginal distribution over all possible return values or the possible values at a particular sample site.
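The idea behind `poutine.queue` can be sketched in pure Python: rerun the model from the start with a partial assignment of its discrete sites, escape (via an exception) at the first unassigned site, push one extended assignment per support value onto a FIFO queue, and weight each complete run by its joint probability. This is a simplified sketch, not Pyro's implementation; `NonlocalExit`, `run`, `enumerate_marginal` and the `two_coins` model are hypothetical.

```python
import math
from collections import deque

class NonlocalExit(Exception):
    """Raised at the first discrete site with no value in the partial trace."""
    def __init__(self, name, support):
        super().__init__(name)
        self.name, self.support = name, support

def run(model, partial):
    """Rerun model with the values in partial; return (result, log joint)."""
    log_w = 0.0
    def sample(name, probs, obs=None):
        nonlocal log_w
        if obs is not None:
            value = obs                              # observed site: value is fixed
        elif name in partial:
            value = partial[name]                    # replay an earlier choice
        else:
            raise NonlocalExit(name, list(probs))    # escape: not yet enumerated
        log_w += math.log(probs[value])
        return value
    return model(sample), log_w

def enumerate_marginal(model):
    q = deque([{}])                                  # seed with an empty partial trace
    weights = {}
    while q:
        partial = q.popleft()
        try:
            result, log_w = run(model, partial)
        except NonlocalExit as site:
            for v in site.support:                   # breadth-first extension
                q.append({**partial, site.name: v})
            continue
        weights[result] = weights.get(result, 0.0) + math.exp(log_w)
    total = sum(weights.values())
    return {k: w / total for k, w in weights.items()}

def two_coins(sample):
    a = sample("a", {0: 0.5, 1: 0.5})
    b = sample("b", {0: 0.5, 1: 0.5})
    p = 0.9 if (a or b) else 0.1                     # noisy report of "at least one head"
    sample("report", {1: p, 0: 1 - p}, obs=1)
    return a

print(enumerate_marginal(two_coins))  # expect P(a=1 | report=1) = 9/14
```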

In [0]:
"""
Do NOT run this chunk - just for reference
"""

# NonlocalExit is the exception EscapeMessenger uses to return site information.
# Related Pyro internals:
#   NonlocalExit.reset_stack(): resets the state of the frames remaining in the stack;
#       needed for repeated re-execution in poutine.queue (omitted in the copy below)
#   am_i_wrapped(): checks whether the current computation is wrapped in a poutine (returns bool)
#   apply_stack(initial_msg): executes the effect stack at a single site

class NonlocalExit(Exception):
    def __init__(self, site, *args, **kwargs):
        super(NonlocalExit, self).__init__(*args, **kwargs)
        self.site = site
        
# Messenger that does a nonlocal exit by raising a util.NonlocalExit exception
class EscapeMessenger(Messenger):
  
    def __init__(self, escape_fn):
        # escape_fn: function that takes a msg as input and returns True
        # if the poutine should perform a nonlocal exit at that site.
        super(EscapeMessenger, self).__init__()
        self.escape_fn = escape_fn

    def _pyro_sample(self, msg):
        # returns a sample from the stochastic function at the site.
        # Evaluates self.escape_fn on the site (self.escape_fn(msg)).
        # If this returns True, raises an exception NonlocalExit(msg).
        # Else, implements default _pyro_sample behavior with no additional effects.
        if self.escape_fn(msg):
            msg["done"] = True
            msg["stop"] = True

            def cont(m):
                raise NonlocalExit(m)
            msg["continuation"] = cont
        return None

In [14]:
import queue
from pyro.infer import SVI, Trace_ELBO
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger

# Some preparation

# Initialize queue
temp = queue.Queue(10)
# Insert an element
temp.put(2)
# Get and remove the element
print(temp.get())

# A pure-Python equivalent of functools.partial:
# pre-binds some positional and keyword arguments; keyword arguments
# passed at call time override the pre-bound ones
def partial(func, *args, **keywords):
  
    def newfunc(*fargs, **fkeywords):
        newkeywords = keywords.copy()
        newkeywords.update(fkeywords)
        return func(*(args + fargs), **newkeywords)
      
    newfunc.func = func
    newfunc.args = args
    newfunc.keywords = keywords
    return newfunc

# replay(fn=None, trace=None, params=None)
# Given a callable that contains Pyro primitive calls, return a callable that runs the original, 
# reusing the values at sites in trace at those sites in the new trace
def model(x):
  s = pyro.param("s", torch.tensor(0.5))
  z = pyro.sample("z", dist.Normal(x, s))
  return z ** 2

old_trace = poutine.trace(model).get_trace(1.0)
replayed_model = poutine.replay(model, trace=old_trace)
print(replayed_model(0.0))
bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"])

2
tensor(1.9962, grad_fn=<PowBackward0>)


True
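The hand-written `partial` in the cell above follows the same rules as the standard library's `functools.partial`: pre-bound positional arguments come first, and keyword arguments given at call time override pre-bound ones. A quick standard-library check, where `power` is just an illustrative function:

```python
import functools

def power(base, exp=2, mod=None):
    result = base ** exp
    return result % mod if mod is not None else result

square = functools.partial(power, exp=2)
cube_mod_10 = functools.partial(power, exp=3, mod=10)

print(square(7))          # 49
print(cube_mod_10(7))     # 3, i.e. 343 % 10
print(square(7, exp=3))   # 343: the call-time keyword overrides exp=2
print(cube_mod_10.func is power, cube_mod_10.keywords)  # True {'exp': 3, 'mod': 10}
```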

In [0]:
# enumerate over possible values of all discrete variables in a model 
# and compute a marginal distribution over all possible return values or the possible values at a particular sample site

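# Reference only (do NOT run): this mirrors poutine.queue and relies on names
# (util, trace, escape, replay, functools) that are not imported here; defining it
# would also shadow the `queue` module imported above.

# poutine.queue: sequential enumeration over discrete variables.
# Given a stochastic function and a queue,
# return a return value from a complete trace in the queue.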

def queue(fn=None, queue=None, max_tries=None,
          extend_fn=None, escape_fn=None, num_samples=None):
    """
    :param fn: a stochastic function (callable containing Pyro primitive calls)
    :param queue: a queue data structure like multiprocessing.Queue to hold partial traces
    :param max_tries: maximum number of attempts to compute a single complete trace
    :param extend_fn: function (possibly stochastic) that takes a partial trace and a site,
        and returns a list of extended traces
    :param escape_fn: function (possibly stochastic) that takes a partial trace and a site,
        and returns a boolean value to decide whether to exit
    :param num_samples: optional number of extended traces for extend_fn to return
    :returns: stochastic function decorated with poutine logic
    """
    # initialization
    if max_tries is None:
        max_tries = int(1e6)
    if extend_fn is None:
        extend_fn = util.enum_extend
    if escape_fn is None:
        escape_fn = util.discrete_escape
    if num_samples is None:
        num_samples = -1

    def wrapper(wrapped):
        def _fn(*args, **kwargs):
            for i in range(max_tries):
                # get next trace from the queue
                assert not queue.empty(), \
                    "trying to get() from an empty queue will deadlock"
                next_trace = queue.get()
                
                try:
                    ftr = trace(escape(replay(wrapped, trace=next_trace), escape_fn=functools.partial(escape_fn,next_trace)))
                    return ftr(*args, **kwargs)
                except NonlocalExit as site_container:
                    site_container.reset_stack()
                    for tr in extend_fn(ftr.trace.copy(),site_container.site,num_samples=num_samples):
                        queue.put(tr)
                        
            raise ValueError("max tries ({}) exceeded".format(str(max_tries)))
        return _fn
    return wrapper(fn) if fn is not None else wrapper

In [0]:
def sequential_discrete_marginal(model, data, site_name="_RETURN"):

    from six.moves import queue  # queue data structures
    q = queue.Queue()  # Instantiate a first-in first-out queue
    q.put(poutine.Trace())  # seed the queue with an empty trace

    # as before, we fix the values of observed random variables with poutine.condition
    # assuming data is a dictionary whose keys are names of sample sites in model
    conditioned_model = poutine.condition(model, data=data)

    # we wrap the conditioned model in a poutine.queue,
    # which repeatedly pushes and pops partially completed executions from a Queue()
    # to perform breadth-first enumeration over the set of values of all discrete sample sites in model
    enum_model = poutine.queue(conditioned_model, queue=q)

    # actually perform the enumeration by repeatedly tracing enum_model
    # and accumulate samples and trace log-probabilities for postprocessing
    samples, log_weights = [], []
    while not q.empty():
        # record the sample value and log weight at this trace site
        trace = poutine.trace(enum_model).get_trace()
        samples.append(trace.nodes[site_name]["value"])
        log_weights.append(trace.log_prob_sum())

    # we take the samples and log-joints and turn them into a histogram:
    samples = torch.stack(samples, 0)
    log_weights = torch.stack(log_weights, 0)
    log_weights = log_weights - dist.util.logsumexp(log_weights, dim=0)
    # Empirical distribution associated with the sampled data
    return dist.Empirical(samples, log_weights)

### Example 3: Implementing lazy evaluation with the Messenger API

In [17]:
class Foo:
  a = 5
fooInstance = Foo()

print(isinstance(fooInstance, Foo))
print(isinstance(fooInstance, (list, tuple)))
print(isinstance(fooInstance, (list, tuple, Foo)))

# isinstance(object, classinfo)
# object - object to be checked
# classinfo - class, type, or tuple of classes and types
# Returns True if object is an instance of classinfo (or of a subclass of it),
# or of any class in the tuple


True
False
True


In [0]:
# first define a LazyValue class that we will use to build up a computation graph

# With LazyValue, implementing lazy evaluation as a Messenger compatible with other effect handlers is surprisingly easy.
# We just make each msg["value"] a LazyValue and introduce a new operation type "apply" for deterministic operations
class LazyValue(object):
    def __init__(self, fn, *args, **kwargs):
        self._expr = (fn, args, kwargs)
        self._value = None
        
    def __str__(self):
        return "({} {})".format(str(self._expr[0]), " ".join(map(str, self._expr[1])))

    def evaluate(self):
        if self._value is None:
            fn, args, kwargs = self._expr
            
            fn = fn.evaluate() if isinstance(fn, LazyValue) else fn
            print(fn)
            args = tuple(arg.evaluate() if isinstance(arg, LazyValue) else arg for arg in args)
            print(args)
            kwargs = {k: v.evaluate() if isinstance(v, LazyValue) else v for k, v in kwargs.items()}
            print(kwargs.items())
            
            self._value = fn(*args, **kwargs)
        return self._value

class LazyMessenger(pyro.poutine.messenger.Messenger):
    def _process_message(self, msg):
        if msg["type"] in ("apply", "sample") and not msg["done"]:
            print("Print from process message func: ")
            print(msg["name"])
            print(msg["fn"])
            msg["done"] = True
            msg["value"] = LazyValue(msg["fn"], *msg["args"], **msg["kwargs"])

In [0]:
# use pyro.poutine.runtime.effectful as a decorator to expose these operations to LazyMessenger. 
# effectful constructs a message much like the one above and sends it up and down the effect handler stack, 
# but allows us to set the type (in this case, to "apply" instead of "sample") 
# so that these operations aren’t mistaken for sample statements by other effect handlers like TraceMessenger

@effectful(type="apply")
def add(x, y):
    return x + y

@effectful(type="apply")
def mul(x, y):
    return x * y

@effectful(type="apply")
def sigmoid(x):
    return torch.sigmoid(x)

@effectful(type="apply")
def normal(loc, scale):
    return dist.Normal(loc, scale)
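The mechanism described above, a decorator that turns a plain function call into a message with `type="apply"` so that handlers can intercept it, can be imitated in a few lines of pure Python. `effectful`, `Lazy`, `LazyHandler` and `STACK` here are hypothetical stand-ins for `pyro.poutine.runtime.effectful`, `LazyValue` and `LazyMessenger`. With no handler active, calls run eagerly; inside `LazyHandler` they build a deferred expression that is computed (once) only when `evaluate()` is called.

```python
import functools

STACK = []

class Lazy:
    """Deferred call fn(*args); arguments may themselves be Lazy."""
    def __init__(self, fn, *args):
        self.fn, self.args, self._value, self._done = fn, args, None, False
    def evaluate(self):
        if not self._done:  # memoize so the expression is computed once
            args = [a.evaluate() if isinstance(a, Lazy) else a for a in self.args]
            self._value, self._done = self.fn(*args), True
        return self._value

class LazyHandler:
    def __enter__(self):
        STACK.append(self)
        return self
    def __exit__(self, *exc):
        STACK.pop()
    def process(self, msg):
        if msg["type"] == "apply" and not msg["done"]:
            msg["value"], msg["done"] = Lazy(msg["fn"], *msg["args"]), True

def effectful(type):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args):
            msg = {"type": type, "fn": fn, "args": args, "value": None, "done": False}
            for handler in reversed(STACK):
                handler.process(msg)
            if not msg["done"]:  # default behaviour: just call the function
                msg["value"] = fn(*args)
            return msg["value"]
        return wrapper
    return decorator

calls = []

@effectful(type="apply")
def add(x, y):
    calls.append("add")
    return x + y

@effectful(type="apply")
def mul(x, y):
    calls.append("mul")
    return x * y

print(add(mul(2, 3), 1))      # eager: 7
with LazyHandler():
    expr = add(mul(2, 3), 1)  # only builds a Lazy expression
print(type(expr).__name__, expr.evaluate(), expr.evaluate())  # Lazy 7 7
```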

In [0]:
# Applied to another model
def biased_scale(guess):
    weight = pyro.sample("weight", normal(guess, 1.))
    tolerance = pyro.sample("tolerance", normal(0., 0.25))
    return pyro.sample("measurement", normal(add(mul(weight, 0.8), 1.), sigmoid(tolerance)))

In [39]:
with LazyMessenger():
    v = biased_scale(8.5)
    print(v)
    print("Now start v evaluate: ")
    print(v.evaluate())

Print from process message func: 
None
<function normal at 0x7fbe3777a620>
Print from process message func: 
weight
(<function normal at 0x7fbe3777a620> 8.5 1.0)
Print from process message func: 
None
<function normal at 0x7fbe3777a620>
Print from process message func: 
tolerance
(<function normal at 0x7fbe3777a620> 0.0 0.25)
Print from process message func: 
None
<function mul at 0x7fbe3777a400>
Print from process message func: 
None
<function add at 0x7fbe3777aa60>
Print from process message func: 
None
<function sigmoid at 0x7fbe3777a510>
Print from process message func: 
None
<function normal at 0x7fbe3777a620>
Print from process message func: 
measurement
(<function normal at 0x7fbe3777a620> (<function add at 0x7fbe3777aa60> (<function mul at 0x7fbe3777a400> ((<function normal at 0x7fbe3777a620> 8.5 1.0) ) 0.8) 1.0) (<function sigmoid at 0x7fbe3777a510> ((<function normal at 0x7fbe3777a620> 0.0 0.25) )))
((<function normal at 0x7fbe3777a620> (<function add at 0x7fbe3777aa60> (<fun

### Some remaining points to test first

In [40]:
# Sampling and scoring with Pyro distributions

# A first example with a Normal distribution
mu, sigma = 0., 1.
std_normal = dist.Normal(mu, sigma)  # named so it does not shadow the effectful normal() above
x = std_normal.sample()
# compute the log probability of x under the distribution
log_prob1 = std_normal.log_prob(x)
print("sample", x)
print("log prob", log_prob1)

# Now try a Bernoulli distribution
p = 0.5
bernoulli = dist.Bernoulli(p)
y = bernoulli.sample()
# score y under the Bernoulli it came from (not the Normal above)
log_prob2 = bernoulli.log_prob(y)
print("sample", y)
print("log prob", log_prob2)

sample tensor(-0.1651)
log prob tensor(-0.9326)
sample tensor(0.)
log prob tensor(-0.6931)


### Tensor shapes in Pyro
- Use .expand() to draw a batch of samples, or rely on plate to expand automatically.
- Use my_dist.to_event(1) to declare a dimension as dependent.
- Use with pyro.plate('name', size): to declare a dimension as conditionally independent.
- All dimensions must be declared either dependent or conditionally independent.

In [0]:
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('0.3.4')
pyro.enable_validation(True)    # <---- This is always a good idea!

# We'll use this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

### Distributions shapes: batch_shape and event_shape
- PyTorch Tensors have a single .shape attribute, but Distributions have two shape attributes with special meaning: .batch_shape and .event_shape. These two combine to define the total shape of a sample.
- http://docs.pyro.ai/en/0.2.1-release/_modules/pyro/distributions/torch.html#Bernoulli

In [0]:
# Some distribution classes, for reference only: running this cell would shadow the Bernoulli imported from pyro.distributions above
import torch
from pyro.distributions.torch_distribution import TorchDistributionMixin

# e.g: for Bernoulli distribution, provide either probs or logits to sample
class Bernoulli(torch.distributions.Bernoulli, TorchDistributionMixin):
    def expand(self, batch_shape):
        validate_args = self.__dict__.get('validate_args')
        if 'probs' in self.__dict__:
            probs = self.probs.expand(batch_shape)
            return Bernoulli(probs=probs, validate_args=validate_args)
        else:
            logits = self.logits.expand(batch_shape)
            return Bernoulli(logits=logits, validate_args=validate_args)

class Beta(torch.distributions.Beta, TorchDistributionMixin):
    def expand(self, batch_shape):
        validate_args = self.__dict__.get('validate_args')
        concentration1 = self.concentration1.expand(batch_shape)
        concentration0 = self.concentration0.expand(batch_shape)
        return Beta(concentration1, concentration0, validate_args=validate_args)

- Indices over .batch_shape denote conditionally independent random variables, whereas indices over .event_shape denote dependent random variables (i.e. one draw from a distribution).
- Because the dependent random variables define probability together, the .log_prob() method only produces a single number for each event of shape .event_shape. Thus the total shape of .log_prob() is .batch_shape.

- Note that the Distribution.sample() method also takes a sample_shape parameter that indexes over independent identically distributed (iid) random variables, so that
```
x2 = d.sample(sample_shape)
assert x2.shape == sample_shape + batch_shape + event_shape
```
- For example, univariate distributions have an empty event shape (because each number is an independent event). Distributions over vectors like MultivariateNormal have len(event_shape) == 1. Distributions over matrices like InverseWishart have len(event_shape) == 2.
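These shape rules can be checked with plain `torch.distributions`. In this sketch `Independent(base, 1)` plays the role of Pyro's `.to_event(1)` (Pyro's `to_event` returns an `Independent` wrapper around the distribution); the sizes 5, 3 and 4 are arbitrary.

```python
import torch
from torch.distributions import Independent, Normal

base = Normal(torch.zeros(3, 4), torch.ones(3, 4))
d = Independent(base, 1)  # reinterpret the rightmost batch dim as part of each event

sample_shape = torch.Size([5])
x = d.sample(sample_shape)

print(base.batch_shape, base.event_shape)  # torch.Size([3, 4]) torch.Size([])
print(d.batch_shape, d.event_shape)        # torch.Size([3]) torch.Size([4])
assert x.shape == sample_shape + d.batch_shape + d.event_shape  # (5, 3, 4)
assert d.log_prob(x).shape == sample_shape + d.batch_shape      # (5, 3)
```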

In [0]:
# Some examples

# The simplest distribution shape is a single univariate distribution.
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()

x = d.sample()
assert x.shape == ()
assert d.log_prob(x).shape == ()
assert x.shape == d.batch_shape + d.event_shape
assert d.log_prob(x).shape == d.batch_shape

In [44]:
print(0.5 * torch.ones(3,4))

# Distributions can be batched by passing in batched parameters.
d = Bernoulli(0.5 * torch.ones(3,4))
assert d.batch_shape == (3, 4) # conditionally independent random variables
# event_shape is empty because each event of a univariate
# distribution is a single scalar
assert d.event_shape == ()

x = d.sample()
print(x)
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

tensor([[0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000]])
tensor([[1., 1., 1., 0.],
        [0., 1., 0., 1.],
        [1., 1., 0., 0.]])


In [45]:
# Another way to batch distributions is via the .expand() method. 
# This only works if parameters are identical along the leftmost dimensions.

d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand([3, 4])
print(d)

assert d.batch_shape == (3, 4)
assert d.event_shape == ()

x = d.sample()
print(x)
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

Bernoulli(probs: torch.Size([3, 4]))
tensor([[0., 1., 0., 1.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.]])


In [46]:
# Multivariate distributions have nonempty .event_shape. 
# For these distributions, the shapes of .sample() and .log_prob(x) differ

d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == () # a single (unbatched) distribution over 3-dimensional vectors
assert d.event_shape == (3,)

x = d.sample()
print(x)

assert x.shape == (3,)            # == batch_shape + event_shape
assert d.log_prob(x).shape == ()  # == batch_shape

tensor([0.1487, 1.1159, 0.3977])


### Reshaping distributions

In [47]:
# In Pyro you can treat a univariate distribution as multivariate by calling the .to_event(n) method,
# where n is the number of batch dimensions (from the right) to declare as dependent.

# a batch of 3 independent 4-dimensional Bernoulli distributions
d = Bernoulli(0.5 * torch.ones(3,4)).to_event(1)
assert d.batch_shape == (3,)
assert d.event_shape == (4,)

x = d.sample()
print(x)

assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3,) # one log_prob per batch element

# samples have shape batch_shape + event_shape, whereas .log_prob(x) values have shape batch_shape. 

tensor([[0., 1., 0., 1.],
        [1., 1., 0., 1.],
        [0., 0., 1., 0.]])


In [48]:
def to_event_test(n):
    d = Bernoulli(0.5 * torch.ones(3,4)).to_event(n)
    print(d.batch_shape)
    print(d.event_shape)

    x = d.sample()
    print(x)

    print(x.shape)
    print(d.log_prob(x).shape)
    print("\n")

# n can be at most 2 here, because the batch shape (3, 4) has only 2 dimensions
to_event_test(0) # all 12 entries conditionally independent: batch_shape (3, 4)
to_event_test(1) # a batch of 3 independent 4-dimensional distributions
to_event_test(2) # a single distribution over 3x4 matrices

torch.Size([3, 4])
torch.Size([])
tensor([[1., 0., 1., 0.],
        [0., 1., 1., 0.],
        [0., 0., 0., 0.]])
torch.Size([3, 4])
torch.Size([3, 4])


torch.Size([3])
torch.Size([4])
tensor([[0., 0., 1., 1.],
        [0., 0., 1., 0.],
        [0., 0., 1., 1.]])
torch.Size([3, 4])
torch.Size([3])


torch.Size([])
torch.Size([3, 4])
tensor([[1., 0., 1., 1.],
        [0., 0., 1., 1.],
        [1., 1., 1., 1.]])
torch.Size([3, 4])
torch.Size([])
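The three cases above follow one bookkeeping rule: `.to_event(n)` moves the n rightmost batch dimensions into the event shape. The helper `to_event_shapes` is a hypothetical pure-Python sketch of that rule (not Pyro's implementation); it rejects n larger than the number of batch dimensions, which is why n stops at 2 for a `(3, 4)` batch.

```python
def to_event_shapes(batch_shape, event_shape, n):
    """Sketch: shapes after .to_event(n) moves n rightmost batch dims to the event."""
    if not 0 <= n <= len(batch_shape):
        raise ValueError(f"n={n} but batch_shape has only {len(batch_shape)} dims")
    split = len(batch_shape) - n
    new_batch = tuple(batch_shape[:split])
    new_event = tuple(batch_shape[split:]) + tuple(event_shape)
    return new_batch, new_event


# Bernoulli(0.5 * torch.ones(3, 4)) starts with batch_shape (3, 4), event_shape ()
for n in range(3):
    print(n, to_event_shapes((3, 4), (), n))
# 0 ((3, 4), ())
# 1 ((3,), (4,))
# 2 ((), (3, 4))
```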




In [49]:
# Often in Pyro we’ll declare some dimensions as dependent even though they are in fact independent.
# Here .expand([3]) gives batch_shape (3,), and .to_event(1) turns it into event_shape (3,).

x = pyro.sample("x", dist.Normal(0, 1).expand([3]).to_event(1)) # n can be at most 1 here
print(x)
assert x.shape == (3, ) # one draw from a 3-dimensional Normal with diagonal covariance

x = pyro.sample("x", dist.Normal(0, 1).expand([3, 3]).to_event(1)) # n can be at most 2 here
print(x)
assert x.shape == (3, 3) # batch_shape (3,), event_shape (3,): one draw from each of 3 such distributions

tensor([-0.1687,  1.0292, -0.7435])
tensor([[ 1.3619, -0.9674,  0.7154],
        [ 0.2573,  0.3683,  0.9159],
        [ 1.1051,  0.6188,  0.7229]])
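Declaring independent dimensions as dependent does not change the density: `.to_event(1)` sums the univariate log-densities over the reinterpreted dimension, and for standard normals that sum equals the log-density of a multivariate Normal with zero mean and identity covariance. The sketch below checks this numerically with hand-written densities; the helper names are hypothetical and no Pyro code is called.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log-density of a univariate Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def std_mvn_log_prob(xs):
    """Log-density of a k-dim Normal with zero mean and identity covariance."""
    k = len(xs)
    return -0.5 * k * math.log(2 * math.pi) - 0.5 * sum(v * v for v in xs)


x = [-0.1687, 1.0292, -0.7435]               # the first sample printed above
summed = sum(normal_log_prob(v) for v in x)  # the sum that .to_event(1) reports
print(math.isclose(summed, std_mvn_log_prob(x)))  # True
```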


### Declaring independent dims with plate

- Use the context manager pyro.plate to declare that certain batch dimensions are independent. Inference algorithms can then take advantage of this independence, e.g. to construct lower-variance gradient estimators or to enumerate in linear rather than exponential space.
- An example of an independent dimension is the index over data in a minibatch: each datum should be independent of all others.
- The simplest option is to declare the rightmost batch dimension as independent with a single plate context manager:
```
with pyro.plate("my_plate"):
    # within this context, batch dimension -1 is independent
    ...
```

In [0]:
# plate can make use of conditional independence information when estimating gradients.
# Reference only: the plate bodies are elided with `...`.
my_data = torch.zeros(100)  # hypothetical placeholder data, only so the cell runs

# Dims are counted from the right using negative indices like -2, -1.
# Providing the optional size argument aids in debugging shapes.
with pyro.plate("my_plate", len(my_data)):
    # within this context, batch dimension -1 is independent
    ...

# nest plates for per-pixel independence
with pyro.plate("x_axis", 320):
    # within this context, batch dimension -1 is independent
    with pyro.plate("y_axis", 200):
        # within this context, batch dimensions -2 and -1 are independent
        ...
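How the nested plates above end up at dims -1 and -2 can be sketched in plain Python. This is a simplified model, not Pyro's allocator: plates are processed outermost first, and a plate without an explicit dim takes the rightmost dim that no enclosing plate already uses. The helper `allocate_dims` is hypothetical.

```python
def allocate_dims(plates):
    """Assign a negative dim to each plate, outermost first.

    plates: list of (name, dim) pairs; dim=None requests automatic allocation.
    """
    used = {}
    for name, dim in plates:
        if dim is None:
            # Take the rightmost dim not used by an enclosing plate.
            dim = -1
            while dim in used.values():
                dim -= 1
        elif dim in used.values():
            raise ValueError(f"dim {dim} is already used by another plate")
        used[name] = dim
    return used


# The nested plates above: "x_axis" is entered first, then "y_axis"
print(allocate_dims([("x_axis", None), ("y_axis", None)]))
# {'x_axis': -1, 'y_axis': -2}
```

Under this allocation, a scalar sample drawn inside both nested plates has shape (200, 320).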

- Finally, if you want to mix and match plates, e.g. for some noise that depends only on x, some that depends only on y, and some that depends on both, you can declare multiple plates and reuse them as context managers.
- In this case Pyro cannot automatically allocate a dimension, so you need to provide a dim argument (again counting from the right):

In [0]:
"""
Do NOT run this chunk - just for reference
"""
x_axis = pyro.plate("x_axis", 3, dim=-2)
y_axis = pyro.plate("y_axis", 2, dim=-3)
with x_axis:
    # within this context, batch dimension -2 is independent
with y_axis:
    # within this context, batch dimension -3 is independent
with x_axis, y_axis:
    # within this context, batch dimensions -3 and -2 are independent
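With explicit dims, the batch shape a scalar-parameter sample takes inside these plates is fully determined: each plate puts its size at its dim, and every unclaimed dim is filled with 1 so results broadcast against each other. The helper `plate_batch_shape` is a hypothetical pure-Python sketch of that bookkeeping; the sizes and dims match `x_axis` and `y_axis` above, and the results agree with the assertions in `model1` below.

```python
def plate_batch_shape(plates, event_shape=()):
    """Shape of a scalar-parameter sample drawn inside plates with explicit dims.

    plates: list of (size, dim) pairs with negative dims.
    """
    ndim = max((-dim for _, dim in plates), default=0)
    shape = [1] * ndim
    for size, dim in plates:
        shape[dim] = size  # a negative index counts from the right
    return tuple(shape) + tuple(event_shape)


X_AXIS = (3, -2)  # like pyro.plate("x_axis", 3, dim=-2)
Y_AXIS = (2, -3)  # like pyro.plate("y_axis", 2, dim=-3)
print(plate_batch_shape([X_AXIS]))                # (3, 1)
print(plate_batch_shape([Y_AXIS]))                # (2, 1, 1)
print(plate_batch_shape([X_AXIS, Y_AXIS]))        # (2, 3, 1)
print(plate_batch_shape([X_AXIS, Y_AXIS], (5,)))  # (2, 3, 1, 5)
```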

In [50]:
# Inside a plate, a sample with scalar parameters is broadcast to the plate size automatically,
# so the explicit .expand([10]) can be skipped.
with pyro.plate("x_plate", 10):
    x = pyro.sample("x", dist.Normal(0, 1))  # .expand([10]) is automatic
    print(x)
    assert x.shape == (10,)

tensor([ 2.0762, -0.3551, -1.5438, -0.1028, -1.2283,  0.6111,  1.9926,  1.1711,
         0.6425,  0.3730])
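Inside a plate, the distribution's batch shape is broadcast against the plate's size. That is why the scalar `Normal(0, 1)` above yields 10 samples, and why a plate size must agree with any non-singleton batch dimension at its dim (as with `c_plate` in `model1` below). The hypothetical helper `broadcast_shapes` models this right-aligned broadcasting rule in plain Python; it is a simplification of what a plate does, not Pyro's implementation.

```python
def broadcast_shapes(*shapes):
    """Right-aligned broadcasting of shape tuples, as in torch and NumPy."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for i in range(-ndim, 0):
        # Sizes other than 1 present at this dim must all agree.
        sizes = {s[i] for s in shapes if len(s) >= -i and s[i] != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


# pyro.plate("x_plate", 10) contributes (10,); Normal(0, 1) has batch_shape ()
print(broadcast_shapes((10,), ()))   # (10,)
# pyro.plate("c_plate", 2) with Normal(torch.zeros(2), 1): the sizes agree
print(broadcast_shapes((2,), (2,)))  # (2,)
```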


### A closer look at batch shapes within plates

In [51]:
def model1():
    a = pyro.sample("a", Normal(0, 1))
    print(a)
    b = pyro.sample("b", Normal(torch.zeros(2), 1).to_event(1))
    print(b)
    
    # The plate size (2) must match the rightmost batch dim of Normal(torch.zeros(2), 1)
    with pyro.plate("c_plate", 2):
        c = pyro.sample("c", Normal(torch.zeros(2), 1))
        print(c)
    with pyro.plate("d_plate", 3):
        # .to_event(2) declares the 2 rightmost dims (4, 5) dependent; the remaining batch dim (3) matches the plate
        d = pyro.sample("d", Normal(torch.zeros(3,4,5), 1).to_event(2))
        print(d)
        
    assert a.shape == ()       # batch_shape == ()     event_shape == ()
    assert b.shape == (2,)     # batch_shape == ()     event_shape == (2,)
    assert c.shape == (2,)     # batch_shape == (2,)   event_shape == ()
    assert d.shape == (3,4,5)  # batch_shape == (3,)   event_shape == (4,5)

    x_axis = pyro.plate("x_axis", 3, dim=-2)
    y_axis = pyro.plate("y_axis", 2, dim=-3)
    # dim (int) – An optional dimension to use for this independence index. 
    # If specified, dim should be negative, i.e. should index from the right. 
    # If not specified, dim is set to the rightmost dim that is left of all enclosing plate contexts.
    
    with x_axis:
        x = pyro.sample("x", Normal(0, 1))
        print(x)
    with y_axis:
        y = pyro.sample("y", Normal(0, 1))
    with x_axis, y_axis:
        xy = pyro.sample("xy", Normal(0, 1))
        z = pyro.sample("z", Normal(0, 1).expand([5]).to_event(1))
        
    assert x.shape == (3, 1)        # batch_shape == (3,1)     event_shape == ()
    assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1)   event_shape == ()
    assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1)   event_shape == ()
    assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1)   event_shape == (5,)

test_model(model1, model1, Trace_ELBO())

tensor(-1.1764)
tensor([0.2359, 1.1323])
tensor([ 0.2678, -1.5047])
tensor([[[-1.1603,  0.0037,  0.5752, -1.8553, -0.2216],
         [-1.2801,  1.0434, -0.2138,  0.9126,  0.3099],
         [-1.1750,  0.9498,  1.7685,  0.9414, -0.9941],
         [ 0.2179,  0.1842,  1.9193, -0.5476, -0.6444]],

        [[ 1.4856, -0.0487, -0.6052, -0.0714, -0.6591],
         [-0.5913,  0.0549,  0.5166, -1.7880, -0.9646],
         [ 2.3067,  0.7811,  0.1099, -2.4119,  0.8760],
         [ 1.0870, -0.5308, -2.2917,  0.3445, -0.3632]],

        [[ 1.3317, -1.5827,  0.1967,  1.2679,  1.8436],
         [ 2.4306, -1.0752,  0.2661, -1.5314, -0.2388],
         [-1.5428, -0.7235,  0.2389,  0.3362, -1.9714],
         [-0.9448, -0.1025, -0.3308, -1.3757, -0.3435]]])
tensor([[-0.9617],
        [-1.2781],
        [-0.0134]])
tensor(-1.1764)
tensor([0.2359, 1.1323])
tensor([ 0.2678, -1.5047])
tensor([[[-1.1603,  0.0037,  0.5752, -1.8553, -0.2216],
         [-1.2801,  1.0434, -0.2138,  0.9126,  0.3099],
         [-1.175

In [52]:
trace = poutine.trace(model1).get_trace()
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())

tensor(-1.5571)
tensor([ 2.1526, -0.2416])
tensor([-0.5817, -0.2862])
tensor([[[ 1.7491,  1.2164, -1.5942, -0.9981, -0.6127],
         [-1.0252,  1.6232,  1.0279,  0.1608,  0.1901],
         [ 0.3066,  0.8032,  0.0299, -0.4769,  0.4482],
         [-0.1270,  0.1544, -0.9237, -0.8893,  1.2569]],

        [[ 1.3002,  1.2701, -1.4480,  0.9792,  0.2713],
         [-0.2358, -0.2328, -0.3218,  0.6783, -0.0649],
         [ 1.3138, -1.1190,  0.4195, -0.5550, -0.5623],
         [-0.1518,  0.9876, -1.1545, -0.0756,  0.1612]],

        [[-0.2436,  0.5915, -0.7203,  1.2710, -0.5661],
         [-2.2154, -0.2767,  1.5700, -0.5976,  0.3555],
         [-0.5456, -2.2799,  0.8896, -1.5686, -0.0903],
         [ 0.7392,  1.5704, -0.2646,  0.2411, -0.1154]]])
tensor([[-0.8905],
        [-1.5685],
        [-0.8145]])
Trace Shapes:            
 Param Sites:            
Sample Sites:            
       a dist       |    
        value       |    
     log_prob       |    
       b dist       | 2  
        valu

### Subsampling tensors inside a plate

- To subsample data, you need to inform Pyro of both the original data size and the subsample size; Pyro will then choose a random subset of data and yield the set of indices.

In [53]:
data = torch.arange(100.)
print(data)

def model2():
    mean = pyro.param("mean", torch.zeros(len(data)))
    with pyro.plate("data", len(data), subsample_size=10) as ind:
        assert len(ind) == 10    # ind is a LongTensor that indexes the subsample.
        batch = data[ind]        # Select a minibatch of data.
        mean_batch = mean[ind]   # Take care to select the relevant per-datum parameters.
        # Do stuff with batch:
        x = pyro.sample("x", Normal(mean_batch, 1), obs=batch)
        assert len(x) == 10

test_model(model2, guide=lambda: None, loss=Trace_ELBO())

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
        28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41.,
        42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
        56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69.,
        70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 81., 82., 83.,
        84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 96., 97.,
        98., 99.])
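The subsampling step itself can be sketched without Pyro: draw `subsample_size` distinct indices, then use them to select both the data and the matching per-datum parameters. Knowing the full size also lets the minibatch log-likelihood be scaled by `size / subsample_size` so it estimates the full-data sum, which is our understanding of what `pyro.plate` does with these two arguments. The helper `subsample` and the list-based data below are hypothetical stand-ins for the tensors in `model2`.

```python
import random


def subsample(size, subsample_size, rng=random):
    """Sketch: random indices without replacement, plus the log-likelihood scale."""
    ind = rng.sample(range(size), subsample_size)
    return ind, size / subsample_size


data = [float(i) for i in range(100)]  # stand-in for torch.arange(100.)
mean = [0.0] * len(data)               # stand-in for the "mean" param
ind, scale = subsample(len(data), 10, random.Random(0))
batch = [data[i] for i in ind]         # select a minibatch of data
mean_batch = [mean[i] for i in ind]    # and the matching per-datum params
print(len(batch), scale)               # 10 10.0
```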


### Broadcasting to allow parallel enumeration

- To use parallel enumeration, Pyro needs to allocate tensor dimensions that it can use for enumeration. To avoid conflicting with other dimensions that we want to use for plates, we need to declare a budget of the maximum number of tensor dimensions we’ll use.

- This budget is called max_plate_nesting and is an argument to SVI (the argument is simply passed through to TraceEnum_ELBO). Usually Pyro can determine this budget on its own (it runs the (model, guide) pair once and records what happens), but in the case of dynamic model structure you may need to declare max_plate_nesting manually.
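A rough sketch of the resulting layout, assuming plates occupy dims -1 through -max_plate_nesting and enumeration starts one dim further left, at -(max_plate_nesting + 1), with each further enumerated site one more dim to the left. The helper `enum_shape` is hypothetical, ignores dim recycling (for example under `pyro.markov`), and only computes the shape an enumerated discrete site would take. Since `model1` uses plates down to dim -3, its budget would be 3.

```python
def enum_shape(num_values, max_plate_nesting, enum_index=0):
    """Sketch: shape of the enum_index-th enumerated site's values."""
    first_enum_dim = -(max_plate_nesting + 1)  # just left of the plate budget
    dim = first_enum_dim - enum_index          # later sites go further left
    # num_values sits at `dim`; every dim to its right is a singleton.
    return (num_values,) + (1,) * (-dim - 1)


print(enum_shape(2, max_plate_nesting=3))                # (2, 1, 1, 1)
print(enum_shape(4, max_plate_nesting=3, enum_index=1))  # (4, 1, 1, 1, 1)
```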

To understand max_plate_nesting and how Pyro allocates dimensions for enumeration, let’s revisit model1() from above. This time we’ll map out three types of dimensions: enumeration dimensions on the left (Pyro takes control of these), batch dimensions in the middle, and event dimensions on the right.