# Mini Pyro Explanation

In [0]:
from __future__ import absolute_import, division, print_function

from collections import OrderedDict
import weakref

import torch

## Pre example for weakref package

In [0]:
# Example comparing a plain dict with weakref.WeakValueDictionary.
# A WeakValueDictionary holds only weak references to its values, so an
# entry goes away once no strong reference to the value remains.
class C: pass
ci=C()
print(ci)

# The name `ci` is the only strong reference to the C instance.
wvd = weakref.WeakValueDictionary({'key' : ci})
print(dict(wvd), len(wvd)) # one entry while `ci` is alive
del ci
print(dict(wvd), len(wvd)) # no entries once `ci` has been deleted

# A regular dict keeps its own strong reference to the value ...
ci2=C()
d=dict()
d['key']=ci2
print(d, len(d))
del ci2
# ... so deleting the name `ci2` does not remove the entry.
print(d, len(d))

<__main__.C object at 0x7fa3a53382b0>
{'key': <__main__.C object at 0x7fa3a53382b0>} 1
{} 0
{'key': <__main__.C object at 0x7fa3a5338278>} 1
{'key': <__main__.C object at 0x7fa3a5338278>} 1


## Pre example for `*args` and `**kwargs`
- Usage of `*args`
  - They are mostly used in function definitions. `*args` and `**kwargs` allow you to pass a variable number of arguments to a function. "Variable" here means that you do not know beforehand how many arguments the caller will pass to your function.
- Usage of `**kwargs`
  - `**kwargs` allows you to pass a keyworded, variable-length set of arguments to a function. You should use `**kwargs` if you want to handle named arguments in a function.

In [0]:
def test_var_args(f_arg, *argv):
    """Print the required argument, then every extra positional argument."""
    print("first normal arg:", f_arg)
    for extra in argv:
        print("another arg through *argv:", extra)


test_var_args('yasoob', 'python', 'eggs', 'test')

first normal arg: yasoob
another arg through *argv: python
another arg through *argv: eggs
another arg through *argv: test


In [0]:
def greet_me(**kwargs):
    """Print each keyword argument as ``key = value`` on its own line."""
    for pair in kwargs.items():
        print("{0} = {1}".format(*pair))


greet_me(name="yasoob")

name = yasoob


In [0]:
# combination
def test_args_kwargs(arg1, arg2, arg3):
    """Print the three arguments, one labelled line each."""
    for label, value in (("arg1:", arg1), ("arg2:", arg2), ("arg3:", arg3)):
        print(label, value)


# Unpack a tuple into the positional parameters with *args.
args = ("two", 3, 5)
test_args_kwargs(*args)

# Unpack a dict into the parameters by name with **kwargs.
kwargs = {"arg3": 3, "arg2": "two", "arg1": 5}
test_args_kwargs(**kwargs)

# some_func(fargs, *args, **kwargs)

arg1: two
arg2: 3
arg3: 5
arg1: 5
arg2: two
arg3: 3


## Pre example for super() function in self-defined class

In [0]:
class Rectangle:
    """Rectangle described by its two side lengths."""

    def __init__(self, length, width):
        self.length, self.width = length, width

    def area(self):
        """Return length times width."""
        side_a, side_b = self.length, self.width
        return side_a * side_b

    def perimeter(self):
        """Return twice the length plus twice the width."""
        side_a, side_b = self.length, self.width
        return 2 * side_a + 2 * side_b

class Square(Rectangle):
    """A rectangle whose length and width are equal."""

    def __init__(self, length):
        # super(Square, self) takes the subclass first and the instance second;
        # it lets us reuse Rectangle.__init__, storing the same side twice.
        super(Square, self).__init__(length=length, width=length)
        
# super() is used to call the __init__() of the Rectangle class,
# so Square can reuse it without repeating code.
# Rectangle is the superclass, and Square is the subclass.
square = Square(4)
# area() is inherited from Rectangle and computes 4 * 4 here.
square.area()

16

## Initialize Pyro_stack and Param_store first

In [0]:
# Pyro keeps track of two kinds of global state:
# i)  The effect handler stack, which enables non-standard interpretations of
#     Pyro primitives like sample();
#     See http://docs.pyro.ai/en/0.3.1/poutine.html
# ii) Trainable parameters in the Pyro ParamStore;
#     See http://docs.pyro.ai/en/0.3.1/parameters.html

# Handlers earlier in the PYRO_STACK are applied first.
PYRO_STACK = list()
# Maps name -> (unconstrained_value, constraint).
PARAM_STORE = dict()


def get_param_store():
    """Return the global parameter store dictionary."""
    return PARAM_STORE

## Messenger -- the very basic effect handler class in Pyro stack
- A generic Messenger contains two methods that are called once per operation where side effects are performed:
  1. `_process_message` modifies a message and sends the result to the Messenger just above it on the stack.
  2. `_postprocess_message` modifies a message and sends the result to the next Messenger down on the stack. It is always called after all active Messengers have had their `_process_message` method applied to the message.

- Although custom Messengers can override `_process_message` and `_postprocess_message`, it is convenient not to require every effect handler to be aware of all possible effectful operation types. For this reason, by default `Messenger._process_message` uses `msg["type"]` to dispatch to a corresponding method `Messenger._pyro_<type>`, e.g. `Messenger._pyro_sample` as in LogJointMessenger. Just as exception handling code ignores unhandled exception types, this allows Messengers to simply forward operations they don't know how to handle up to the next Messenger in the stack (see the dispatch example in the next cell).
- The order in which Messengers are applied to an operation like a `pyro.sample` statement is determined by the order in which their `__enter__` methods are called. `Messenger.__enter__` appends a Messenger to the end (the bottom) of the global handler stack.

In [0]:
# An example for dispatching corresponding method
class Messenger(object):
    ...

    def _process_message(self, msg):
        """Forward ``msg`` to ``self._pyro_<type>`` when such a method exists."""
        # e.g. "_pyro_sample" when msg["type"] == "sample"
        handler_name = "_pyro_{}".format(msg["type"])
        if not hasattr(self, handler_name):
            # Operation types without a matching method are left untouched.
            return
        handler = getattr(self, handler_name)
        handler(msg)

    ...

In [0]:
# The base effect handler class (called Messenger here for consistency with Pyro).
class Messenger(object):
    """Base effect handler (named Messenger for consistency with Pyro).

    A Messenger optionally wraps a callable ``fn``. While it is active (inside
    a ``with`` block, or during a call to the wrapped function) it sits on the
    global PYRO_STACK.
    """

    def __init__(self, fn=None):
        self.fn = fn

    def __enter__(self):
        # Entering appends this handler to the end (the bottom) of the stack.
        PYRO_STACK.append(self)

    def __exit__(self, *args, **kwargs):
        # A handler may only be exited while it is the last element of the
        # stack; it is then removed from it.
        assert PYRO_STACK[-1] is self
        del PYRO_STACK[-1]

    def process_message(self, msg):
        """Hook for subclasses; the base class leaves ``msg`` unchanged."""
        return None

    def postprocess_message(self, msg):
        """Hook for subclasses; the base class leaves ``msg`` unchanged."""
        return None

    def __call__(self, *args, **kwargs):
        """Call the wrapped function with this handler active on the stack."""
        with self:
            result = self.fn(*args, **kwargs)
        return result

## Some Messenger examples to show
- `__enter__` and `__exit__` are special methods needed by any Python context manager.
- When implementing new Messenger classes, if we override `__enter__` and `__exit__`, we always need to call the base Messenger's `__enter__` and `__exit__` methods for the new Messenger to be applied correctly.

### 1. Trace Messenger - record the message info to a dictionary

In [0]:
# A first useful example of an effect handler.
# trace records the inputs and outputs of any primitive site it encloses,
# and returns a dictionary containing that data to the user.

# trace class here can be considered as a kind of messenger
# it inherits Messenger's enter function
# define its own postprocess_message by storing info into a dictionary
class trace(Messenger):
    """Effect handler that records every message it sees, keyed by site name."""

    def __enter__(self):
        super(trace, self).__enter__()
        # Start from an empty record each time the handler is entered.
        recorded = OrderedDict()
        self.trace = recorded
        return recorded

    def postprocess_message(self, msg):
        # Recording happens in postprocess_message (not process_message) so
        # that the stored copy reflects the message after all other effects
        # have been applied.
        site_name = msg["name"]
        assert site_name not in self.trace, "all sites must have unique names"
        self.trace[site_name] = msg.copy()

    def get_trace(self, *args, **kwargs):
        """Run the wrapped function and return the recorded sites."""
        self(*args, **kwargs)
        return self.trace

### 2. Replay Messenger - replace message value according to a trace dictionary

In [0]:
# A second example of an effect handler for setting the value at a sample site.
# This illustrates why effect handlers are a useful PPL implementation technique:
# We can compose trace and replay to replace values but preserve distributions,
# allowing us to compute the joint probability density of samples under a model.
# See the definition of elbo(...) below for an example of this pattern.
class replay(Messenger):
    """Effect handler that overwrites site values with those from a trace."""

    def __init__(self, fn, guide_trace):
        super(replay, self).__init__(fn)
        # Dictionary of recorded sites, keyed by site name.
        self.guide_trace = guide_trace

    def process_message(self, msg):
        site_name = msg["name"]
        if site_name not in self.guide_trace:
            # Sites missing from the trace keep whatever value they have.
            return
        msg["value"] = self.guide_trace[site_name]["value"]

### 3. Block Messenger - hidden messages are NOT processed by handlers further up the stack

In [0]:
# block allows the selective application of effect handlers to different parts of a model.
# Sites hidden by block will only have the handlers below block on the PYRO_STACK applied,
# allowing inference or other effectful computations to be nested inside models.
class block(Messenger):
    """Effect handler that marks selected messages with ``msg["stop"] = True``."""

    def __init__(self, fn=None, hide_fn=lambda msg: True):
        super(block, self).__init__(fn)
        # Predicate choosing which messages to hide; hides everything by default.
        self.hide_fn = hide_fn

    def process_message(self, msg):
        should_hide = self.hide_fn(msg)
        if should_hide:
            msg["stop"] = True

### 4. Plate Messenger - generate conditionally independent samples

In [0]:
# This limited implementation of PlateMessenger only implements broadcasting.
class PlateMessenger(Messenger):
    """Limited plate handler that only implements broadcasting of sample sites.

    Args:
        fn: optional callable to wrap (forwarded to Messenger).
        size: number of entries the batch shape should have at ``dim``.
        dim: negative batch dimension (counted from the right) to expand.
    """

    def __init__(self, fn, size, dim):
        assert dim < 0
        self.size = size
        self.dim = dim
        super(PlateMessenger, self).__init__(fn)

    def process_message(self, msg):
        if msg["type"] == "sample":
            batch_shape = msg["fn"].batch_shape
            # Expand only if the distribution does not already have `size`
            # entries at position `dim` of its batch shape.
            if len(batch_shape) < -self.dim or batch_shape[self.dim] != self.size:
                # Left-pad with singleton dimensions until `dim` is a valid index.
                batch_shape = [1] * (-self.dim - len(batch_shape)) + list(batch_shape)
                batch_shape[self.dim] = self.size
                msg["fn"] = msg["fn"].expand(torch.Size(batch_shape))

    def __iter__(self):
        # __iter__ must return an iterator. A bare range object is only an
        # iterable, so returning it directly makes `for i in plate(...)`
        # raise "TypeError: iter() returned non-iterator of type 'range'".
        return iter(range(self.size))

# boilerplate to match the syntax of actual pyro.plate:
# Construct for conditionally independent sequences of variables
def plate(name, size, dim):
    """Build a PlateMessenger with the given size and dim; ``name`` is not used."""
    return PlateMessenger(None, size, dim)

##  Define sample and param function 
- If there are no active Messengers, we just draw a sample and return it as expected; otherwise we initialize a message and call apply_stack to send it to the Messengers.

In [0]:
# sample is an effectful version of Distribution.sample(...)
# When any effect handlers are active, it constructs an initial message and calls apply_stack.
# fn - distribution function for samples
def sample(name, fn, obs=None):
    """Effectful version of drawing a value from ``fn`` at the site ``name``.

    Args:
        name: name of the sample site.
        fn: callable (the distribution) invoked with no arguments to draw a value.
        obs: optional observed value; it becomes the message's initial "value".

    Returns:
        ``fn()`` when no handlers are active, otherwise the "value" field of
        the message after it has been sent through apply_stack.
    """
    if PYRO_STACK:
        # At least one handler is active: describe the site in a message ...
        site = dict(type="sample", name=name, fn=fn, args=(), value=obs)
        # ... and let apply_stack send it to the Messengers.
        return apply_stack(site)["value"]
    # With no active Messengers, simply draw a sample and return it.
    return fn()

In [0]:
# param is an effectful version of PARAM_STORE.setdefault that also handles constraints.
# When any effect handlers are active, it constructs an initial message and calls apply_stack.
def param(name, init_value=None, constraint=torch.distributions.constraints.real):
    """Effectful version of PARAM_STORE.setdefault that also handles constraints.

    Args:
        name: key of the parameter in PARAM_STORE.
        init_value: constrained initial tensor; required the first time a
            name is used.
        constraint: constraint object used to pick the transform between
            unconstrained and constrained space.

    Returns:
        The constrained value. When produced by the inner ``fn`` it carries an
        ``unconstrained`` attribute: a weak reference to the stored tensor.
    """
    transform_to = torch.distributions.transform_to

    def fn(init_value, constraint):
        if name not in PARAM_STORE:
            # First use of this name: initialize from the constrained value.
            assert init_value is not None
            # No gradient tracking while computing the initial stored value.
            with torch.no_grad():
                # Detach from init_value's graph, then map the constrained
                # value into unconstrained space with the inverse transform.
                detached = init_value.detach()
                raw = transform_to(constraint).inv(detached)
            raw.requires_grad_()
            PARAM_STORE[name] = raw, constraint
        else:
            # Known name: reuse the stored tensor and the stored constraint.
            raw, constraint = PARAM_STORE[name]

        # Transform from unconstrained space to constrained space.
        constrained = transform_to(constraint)(raw)
        constrained.unconstrained = weakref.ref(raw)
        return constrained

    if PYRO_STACK:
        # Handlers are active: build the initial message and send it to them.
        site = dict(
            type="param",
            name=name,
            fn=fn,
            args=(init_value, constraint),
            value=None,
        )
        return apply_stack(site)["value"]
    # With no active Messengers, look up or create the parameter directly.
    return fn(init_value, constraint)

## Full initial message 
- The actual messages sent up and down the stack are dictionaries with a particular set of keys. The full initial message is written out here for completeness:

In [0]:
# Illustrative example of a complete message for a sample site named "x".
# NOTE(review): `dist` is not imported anywhere above this cell in the file
# as shown - presumably pyro.distributions; confirm before running the cell.
msg = {
    # The following fields contain the name, inputs, function, and output of a site.
    # These are generally the only fields you'll need to think about.
    "name": "x",
    "fn": dist.Bernoulli(0.5),
    "value": None,  # msg["value"] will eventually contain the value returned by pyro.sample
    "is_observed": False,  # because obs=None by default; only used by sample sites
    "args": (),  # positional arguments passed to "fn" when it is called; usually empty for sample sites
    "kwargs": {},  # keyword arguments passed to "fn" when it is called; usually empty for sample sites
    # This field typically contains metadata needed or stored by a particular inference algorithm
    "infer": {"enumerate": "parallel"},
    # The remaining fields are generally only used by Pyro's internals,
    # or for implementing more advanced effects beyond the scope of this tutorial
    "type": "sample",  # label used by Messenger._process_message to dispatch, in this case to _pyro_sample
    "done": False,
    "stop": False,
    "scale": torch.tensor(1.),  # Multiplicative scale factor that can be applied to each site's log_prob
    "mask": None,
    "continuation": None,
    "cond_indep_stack": (),  # Will contain metadata from each pyro.plate enclosing this sample site.
}

## apply_stack function - transfer message to stack for operations
-  traverses the stack twice at each operation:
  - first from bottom to top to apply each _process_message 
  - and then from top to bottom to apply each _postprocess_message

In [0]:
# apply_stack is called by pyro.sample and pyro.param.
# It is responsible for applying each Messenger to each effectful operation.
def apply_stack(msg):
    """Apply every active Messenger to ``msg`` and return the processed message.

    The stack is traversed twice: first from the bottom (last entered) to the
    top calling ``process_message``, then back down calling
    ``postprocess_message`` on the handlers that were reached in the first
    pass. In between, ``msg["value"]`` is computed from ``msg["fn"]`` and
    ``msg["args"]`` if no handler has set it.

    Args:
        msg: message dictionary; must contain "value", "fn" and "args".

    Returns:
        The same dictionary, mutated in place.
    """
    # Position (counted from the bottom of the stack) of the last handler whose
    # process_message ran. Starting at 0 keeps the slice below well defined
    # (and empty) when PYRO_STACK is empty, instead of raising
    # UnboundLocalError because the loop never assigned `pointer`.
    pointer = 0
    for pointer, handler in enumerate(reversed(PYRO_STACK)):
        handler.process_message(msg)
        # When a Messenger sets the "stop" field of a message,
        # it prevents any Messengers above it on the stack from being applied.
        if msg.get("stop"):
            break
    if msg["value"] is None:
        # No handler supplied a value: call the function with its arguments.
        msg["value"] = msg["fn"](*msg["args"])

    # A Messenger that sets msg["stop"] == True also prevents application
    # of postprocess_message by Messengers above it on the stack
    # via the pointer variable from the process_message loop.
    for handler in PYRO_STACK[-pointer-1:]:
        handler.postprocess_message(msg)
    return msg

## Adam optimizer class
- It dynamically generates optimizers for dynamically generated parameters

In [0]:
# This is a thin wrapper around the `torch.optim.Adam` class that
# dynamically generates optimizers for dynamically generated parameters.
# See http://docs.pyro.ai/en/0.3.1/optimization.html
class Adam(object):
    """Thin wrapper around ``torch.optim.Adam`` that creates optimizers lazily.

    Each parameter gets its own ``torch.optim.Adam`` instance, built with
    ``optim_args`` the first time the parameter is seen.
    """

    def __init__(self, optim_args):
        # One torch optimizer per parameter, keyed on the parameter itself.
        self.optim_objs = {}
        self.optim_args = optim_args

    def __call__(self, params):
        """Take one gradient step for every parameter in ``params``."""
        for tensor in params:
            optimizer = self.optim_objs.get(tensor)
            if optimizer is None:
                # First time this parameter is seen: build and remember
                # an optimizer dedicated to it.
                optimizer = torch.optim.Adam([tensor], **self.optim_args)
                self.optim_objs[tensor] = optimizer
            optimizer.step()

## SVI example - contain Messenger Trace and Block to store parameters values only
- To be more specific: the Trace Messenger is entered first and the Block Messenger second, so Block sits at the bottom of the stack and is applied first (the stack is traversed from the bottom up). For sample sites, Block's `process_message` sets `msg["stop"] = True`, so the message never reaches the Trace Messenger above it and is not recorded. For param sites, Block lets the message through, so the Trace Messenger records it successfully.

In [0]:
# This is a unified interface for stochastic variational inference in Pyro.
# The actual construction of the loss is taken care of by `loss`.
# See http://docs.pyro.ai/en/0.3.1/inference_algos.html
class SVI(object):
    """Unified interface for stochastic variational inference.

    The construction of the loss itself is delegated to ``loss``, which is
    called as ``loss(model, guide, *args, **kwargs)``.
    """

    def __init__(self, model, guide, optim, loss):
        self.model, self.guide = model, guide
        self.optim, self.loss = optim, loss

    def step(self, *args, **kwargs):
        """Run model and guide, take one gradient step, and return the loss value."""

        def is_sample_site(msg):
            return msg["type"] == "sample"

        # The outer trace records the sites that reach it; block hides sample
        # sites, so only parameters are captured here. Further tracing occurs
        # inside of the loss.
        with trace() as param_capture, block(hide_fn=is_sample_site):
            loss = self.loss(self.model, self.guide, *args, **kwargs)
        # Differentiate the loss.
        loss.backward()
        # Collect the unconstrained tensor behind every captured parameter.
        params = []
        for site in param_capture.values():
            params.append(site["value"].unconstrained())
        # Take a step w.r.t. each parameter.
        self.optim(params)
        # Reset the gradients so that they do not accumulate across steps.
        for tensor in params:
            tensor.grad = tensor.new_zeros(tensor.shape)
        return loss.item()

## ELBO calculation - using Trace Messenger to record and run the model
- We've defined a Pyro model with observations $x$ and latents $z$ of the form $p_\theta(x, z) = p_\theta(x \mid z)\, p_\theta(z)$. We've also defined a Pyro guide (i.e. a variational distribution) of the form $q_\phi(z)$. Here $\theta$ and $\phi$ are variational parameters for the model and guide, respectively. (In particular these are not random variables that call for a Bayesian treatment.)
- We'd like to maximize the log evidence $\log p_\theta(x)$ by maximizing the ELBO (the evidence lower bound) given by $\mathrm{ELBO} \equiv \mathbb{E}_{q_\phi(z)}\left[\log p_\theta(x, z) - \log q_\phi(z)\right]$

In [0]:
# one model example
def model(data):
    """Beta-Bernoulli model: one latent fairness and one observation per datapoint."""
    # hyperparameters that control the beta prior
    alpha0, beta0 = torch.tensor(10.0), torch.tensor(10.0)
    # draw the latent fairness f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # observe every datapoint under the bernoulli likelihood
    for i, datapoint in enumerate(data):
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=datapoint)
        
# one guide example
def guide(data):
    """Variational distribution Beta(alpha_q, beta_q) over latent_fairness."""
    positive = constraints.positive
    # register the two variational parameters with Pyro
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0), constraint=positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0), constraint=positive)
    # sample latent_fairness from Beta(alpha_q, beta_q)
    variational = dist.Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", variational)
    
# by running replay(model, guide_trace), latent_fairness would be 
# sampled from guide using alpha_q and beta_q

In [0]:
# This is a basic implementation of the Evidence Lower Bound, which is the
# fundamental objective in Variational Inference.
# See http://pyro.ai/examples/svi_part_i.html for details.
# This implementation has various limitations (for example it only supports
# random variables with reparameterized samplers), but all the ELBO
# implementations in Pyro share the same basic logic.
def elbo(model, guide, *args, **kwargs):
    """Return the negated ELBO estimate for ``model`` and ``guide``.

    ``args`` and ``kwargs`` are forwarded unchanged to both callables.
    """
    # Run the guide and record every call to sample() and param().
    guide_trace = trace(guide).get_trace(*args, **kwargs)
    # Run the model under replay, so sample sites that also appear in the
    # guide trace reuse the guide's values instead of drawing new ones; the
    # loss is therefore an expectation w.r.t. the guide.
    replayed_model = replay(model, guide_trace)
    model_trace = trace(replayed_model).get_trace(*args, **kwargs)

    total = 0.
    # Add log p for every sample site of the model; this includes observed
    # data, i.e. sample sites given `obs=...`.
    for site in model_trace.values():
        if site["type"] != "sample":
            continue
        total = total + site["fn"].log_prob(site["value"]).sum()
    # Subtract log q for every sample site of the guide.
    for site in guide_trace.values():
        if site["type"] != "sample":
            continue
        total = total - site["fn"].log_prob(site["value"]).sum()
    # By convention we minimize a loss, while the ELBO is to be maximized.
    return -total


# This is a wrapper for compatibility with full Pyro.
def Trace_ELBO(*args, **kwargs):
    """Return the ``elbo`` loss function; all arguments are accepted and ignored."""
    loss_fn = elbo
    return loss_fn

## A final Example to show all above

In [3]:
!pip install pyro-ppl

Collecting pyro-ppl
[?25l  Downloading https://files.pythonhosted.org/packages/c0/e1/d67bf6252b9a0a1034bfd81c23fd28cdb8078670187f60084c1785bcae42/pyro-ppl-0.3.3.tar.gz (231kB)
[K     |████████████████████████████████| 235kB 5.0MB/s 
Collecting opt_einsum>=2.3.2 (from pyro-ppl)
[?25l  Downloading https://files.pythonhosted.org/packages/f6/d6/44792ec668bcda7d91913c75237314e688f70415ab2acd7172c845f0b24f/opt_einsum-2.3.2.tar.gz (59kB)
[K     |████████████████████████████████| 61kB 18.3MB/s 
Collecting tqdm>=4.31 (from pyro-ppl)
[?25l  Downloading https://files.pythonhosted.org/packages/9f/3d/7a6b68b631d2ab54975f3a4863f3c4e9b26445353264ef01f465dc9b0208/tqdm-4.32.2-py2.py3-none-any.whl (50kB)
[K     |████████████████████████████████| 51kB 17.5MB/s 
[?25hBuilding wheels for collected packages: pyro-ppl, opt-einsum
  Building wheel for pyro-ppl (setup.py) ... [?25l[?25hdone
  Stored in directory: /root/.cache/pip/wheels/37/6b/8b/8d15c6042ed38db155158baf56c1949a6e12d5d709697b0c37
  Bui

### Pre-example for Pyro.plate
- Each invocation of plate requires the user to provide a unique name. The second argument is an integer just like for range.
- Pyro can now leverage the conditional independence of the observations given the latent random variable. `pyro.plate` is implemented as a context manager: at every execution of the body of the for loop we enter a new (conditional) independence context, which is exited at the end of the for loop body.
  - because each observed pyro.sample statement occurs within a different execution of the body of the for loop, Pyro marks each observation as independent.
  - this independence is properly a conditional independence given latent_fairness because latent_fairness is sampled outside of the context of data_loop.

In [0]:
# Comparison between range() and pyro.plate
def model(data):
    # Sequential version: the observations are visited with a plain range() loop.
    # NOTE(review): alpha0 and beta0 are not defined in this cell - presumably
    # the beta prior hyperparameters; confirm before running.
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data using pyro.sample with the obs keyword argument
    for i in range(len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
        
def model(data):
    # Plate version: identical to the range() version except for the loop header,
    # which iterates over pyro.plate instead of range.
    # NOTE(review): alpha0 and beta0 are not defined in this cell - presumably
    # the beta prior hyperparameters; confirm before running.
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data [WE ONLY CHANGE THE NEXT LINE]
    for i in pyro.plate("data_loop", len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

# subsample minibatches of data
# Use pyro.plate here, as the model definitions in this cell do: the plate()
# helper defined earlier in this file requires `dim` and does not accept
# `subsample_size`, so the bare name would raise TypeError.
# `ind` is used to index the minibatch out of `data`.
with pyro.plate("data", len(data), subsample_size=100) as ind:
    batch = data[ind]
    assert len(batch) == 100

In [12]:
"""
This example demonstrates the functionality of `pyro.contrib.minipyro`,
which is a minimal implementation of the Pyro Probabilistic Programming
Language that was created for didactic purposes.
"""

from __future__ import absolute_import, division, print_function
import argparse
import torch
import pyro

# We use the pyro.generic interface to support dynamic choice of backend.
from pyro.generic import pyro_backend
from pyro.generic import distributions as dist
from pyro.generic import infer, optim, pyro


def main(args):
    """Fit a one-latent Normal model with SVI and check the fitted location.

    ``args`` is indexed with the keys "backend", "learning_rate" and
    "num_steps".
    """
    # Model: a single Normal latent random variable `loc` and a batch of
    # Normally distributed observations.
    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.plate("data", len(data), dim=-1):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    # Guide (i.e. variational distribution): a Normal distribution over the
    # latent random variable `loc`, with a log-parameterized scale.
    def guide(data):
        guide_loc = pyro.param("guide_loc", torch.tensor(0.))
        log_scale = pyro.param("guide_scale_log", torch.tensor(0.))
        pyro.sample("loc", dist.Normal(guide_loc, log_scale.exp()))

    # Generate some data (seeded for reproducibility).
    torch.manual_seed(0)
    data = torch.randn(100) + 3.0

    # Because the API in minipyro matches that of Pyro proper,
    # training code works with generic Pyro implementations.
    with pyro_backend(args["backend"]):
        # Build the SVI object for the model/guide pair.
        loss_fn = infer.Trace_ELBO()
        optimizer = optim.Adam({"lr": args["learning_rate"]})
        inference = infer.SVI(model, guide, optimizer, loss_fn)

        # Basic training loop, reporting the loss every 100 steps.
        pyro.get_param_store().clear()
        for step in range(args["num_steps"]):
            loss = inference.step(data)
            if step % 100:
                continue
            print("step {} loss = {}".format(step, loss))

        # Report the final values of the variational parameters
        # in the guide after training.
        for name in pyro.get_param_store():
            value = pyro.param(name)
            print("{} = {}".format(name, value.detach().cpu().numpy()))

        # For this simple (conjugate) model the exact posterior is known: the
        # variational distribution should be centered near 3.0.
        assert (pyro.param("guide_loc") - 3.0).abs() < 0.1

# Run the demo with the minipyro backend; main() reads exactly these three keys.
args = {"num_steps": 1001, "learning_rate": 0.02, "backend": "minipyro"}
main(args)

# if __name__ == "__main__":
#     assert pyro.__version__.startswith('0.3.3')
#     parser = argparse.ArgumentParser(description="Mini Pyro demo")
#     parser.add_argument("-b", "--backend", default="minipyro")
#     parser.add_argument("-n", "--num-steps", default=1001, type=int)
#     parser.add_argument("-lr", "--learning-rate", default=0.02, type=float)
#     args = parser.parse_args()
#     main(args)

step 0 loss = 291.2471618652344
step 100 loss = 164.15792846679688
step 200 loss = 149.47970581054688
step 300 loss = 150.03028869628906
step 400 loss = 165.29713439941406
step 500 loss = 153.3885955810547
step 600 loss = 164.81736755371094
step 700 loss = 150.8622589111328
step 800 loss = 150.74578857421875
step 900 loss = 150.77191162109375
step 1000 loss = 152.4605712890625
guide_loc = 3.0301661491394043
guide_scale_log = -1.7844144105911255
