# SVI Part 1: An introduction to Stochastic Variational Inference
This is taken from [this page](https://pyro.ai/examples/svi_part_i.html) in the Pyro examples section.
Miguel Fuentes
Created: 4/29/2020
Last Updated: 4/30/2020

In [1]:
import math
import os
from tqdm import tqdm

import torch
import torch.distributions.constraints as constraints

import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

pyro.set_rng_seed(101)

## Setup
We can perform SVI on more or less arbitrary stochastic functions with Pyro. Besides its inputs, the main components of a Pyro model are:
1) Observations (included with pyro.sample using the obs keyword)  
2) Latent variables (included with pyro.sample)  
3) Parameters (included with pyro.param)  

Each setting of the parameters defines a joint probability distribution over the observations and the latent variables. To perform SVI we need the following assumptions about these joint pdfs:  
- We can sample from the pdfs
- We can compute the pointwise log pdf at any point
- The pdf is differentiable w.r.t. the parameters

## What exactly are we trying to learn?
We want to find the parameters under which the observed data are most likely, i.e. the parameters that maximize the log evidence:  
$$\theta_{max} = \underset{\theta}{\text{argmax}} \log p_{\theta}(x)$$  
To compute this quantity we must integrate over the latent variables: $p_{\theta}(x) = \int dz \, p_{\theta}(x,z)$. Doing this is often intractable, and even when we can do it we usually end up with a hard non-convex optimization problem.  
We also want to compute the posterior over the latent variables once we have the most likely parameters. This requires another challenging computation:  
$$p_{\theta_{max}}(z|x) = \frac{p_{\theta_{max}}(x,z)}{\int dz \, p_{\theta_{max}}(x,z)}$$  
We don't want to, or often can't, do these calculations, so we need a better way. Variational inference gives us a scheme for calculating $\theta_{max}$ and obtaining an approximation to $p_{\theta_{max}}(z|x)$. For this we need a few things, one of the most important being a guide.
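
To make the integral over the latent variable concrete, here is a minimal sketch for a Beta-Bernoulli coin model, one of the rare cases where the evidence has a closed form. The prior Beta(10, 10) and the counts of 6 heads and 4 tails are illustrative values, and the helper name `log_beta_fn` is hypothetical. The sketch compares the closed-form log evidence with a brute-force midpoint-rule integration over the latent $z$.

```python
import math

def log_beta_fn(a, b):
    """Log of the Beta function B(a, b), computed via log-gamma."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

# Illustrative values: Beta(a, b) prior, observed heads and tails.
a, b = 10.0, 10.0
heads, tails = 6, 4

# Closed form: p(x) = B(a + heads, b + tails) / B(a, b).
log_evidence_exact = log_beta_fn(a + heads, b + tails) - log_beta_fn(a, b)

# Brute force: integrate p(x | z) p(z) over z in (0, 1) with the midpoint rule.
n = 100_000
total = 0.0
for i in range(n):
    z = (i + 0.5) / n
    log_prior = (a - 1) * math.log(z) + (b - 1) * math.log(1 - z) - log_beta_fn(a, b)
    log_lik = heads * math.log(z) + tails * math.log(1 - z)
    total += math.exp(log_prior + log_lik) / n
log_evidence_numeric = math.log(total)

print(f"exact:   {log_evidence_exact:.6f}")
print(f"numeric: {log_evidence_numeric:.6f}")
```

With a single latent variable the grid integration is cheap, but the number of grid points grows exponentially with the number of latent dimensions, which is why this approach does not scale.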

## Guide
The idea here is to introduce a family of distributions $q_{\phi}(z)$ over the latent variables, parameterized by $\phi$. We will search this family for the best possible approximation of $p_{\theta_{max}}(z|x)$. In the literature we call $\phi$ the variational parameters and $q_{\phi}(z)$ the variational distribution. In Pyro, the variational distribution is called the guide because that is shorter and easier to remember.  
We define the guide function the same way we would define any other model in Pyro. However, since we need the guide to produce a joint distribution over the latent variables, we need to impose some constraints:  
1) The model and guide should have the same call signature (args and kwargs)  
2) The guide should not include any observations  
3) Any latent variable that appears in the model (with a pyro.sample call) must also appear in the guide, under the same name  

Once we have defined the guide we can go on to search the distribution space for the best posterior approximation. To do this we need an objective function.

## ELBO
The ELBO (evidence lower bound) is the objective function we are going to optimize. It is defined as $\text{ELBO} = \mathbb{E}_{q_{\phi}(z)}[\log p_{\theta}(x,z) - \log q_{\phi}(z)]$. We choose it because, for a fixed $\theta$, maximizing the ELBO minimizes the KL divergence between $q_{\phi}(z)$ and $p_{\theta}(z|x)$, which is exactly the goal of variational inference. This follows from the identity:  
$$\log p_{\theta}(x) - \text{ELBO} = \text{KL}(q_{\phi}(z) \,||\, p_{\theta}(z|x))$$  
Because the KL divergence is non-negative, the ELBO is a lower bound on the log evidence $\log p_{\theta}(x)$, hence its name. For a fixed θ, as we take steps in ϕ space that increase the ELBO, we decrease the KL divergence between the guide and the posterior, i.e. we move the guide towards the posterior. In the general case we take gradient steps in both θ and ϕ space simultaneously so that the guide and model play chase, with the guide tracking a moving posterior $\log p_{\theta}(z|x)$. Perhaps somewhat surprisingly, despite the moving target, this optimization problem can be solved (to a suitable level of approximation) for many different problems.  
So at a high level variational inference is easy: all we need to do is define a guide and compute gradients of the ELBO. In practice, computing gradients for general model and guide pairs leads to some complications (see the tutorial SVI Part III for a discussion). For the purposes of this tutorial, let’s consider that a solved problem and look at the support that Pyro provides for doing variational inference.
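
The identity relating the log evidence, the ELBO, and the KL divergence can be checked numerically on a toy problem. The sketch below assumes a hypothetical model with a single binary latent variable, so every expectation is an exact sum over two values; the probabilities are made up for illustration, and the ELBO is taken to be $\mathbb{E}_{q}[\log p(x,z) - \log q(z)]$.

```python
import math

# Hypothetical toy model: one binary latent z and one fixed observation x.
prior = {0: 0.5, 1: 0.5}        # p(z)
likelihood = {0: 0.2, 1: 0.9}   # p(x | z) evaluated at the observed x
q = {0: 0.3, 1: 0.7}            # an arbitrary guide q(z)

# Joint p(x, z), evidence p(x), and exact posterior p(z | x).
joint = {z: prior[z] * likelihood[z] for z in prior}
evidence = sum(joint.values())
posterior = {z: joint[z] / evidence for z in joint}

# ELBO = E_q[log p(x, z) - log q(z)], summed exactly over z.
elbo = sum(q[z] * (math.log(joint[z]) - math.log(q[z])) for z in q)

# KL(q || p(z | x)) = E_q[log q(z) - log p(z | x)].
kl = sum(q[z] * (math.log(q[z]) - math.log(posterior[z])) for z in q)

gap = math.log(evidence) - elbo
print(f"log p(x) - ELBO = {gap:.6f}")
print(f"KL(q || p(z|x)) = {kl:.6f}")
```

The two printed numbers agree up to floating-point error, and setting `q` equal to `posterior` drives both to zero.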

## SVI Class
In Pyro the machinery for doing variational inference is encapsulated in the SVI class.
The user needs to provide three things: the model, the guide, and an optimizer. We’ve discussed the model and guide above and we’ll discuss the optimizer in some detail below, so let’s assume we have all three ingredients at hand.  
The SVI object provides two methods, step() and evaluate_loss(), that encapsulate the logic for variational learning and evaluation:  
 - The method step() takes a single gradient step and returns an estimate of the loss (i.e. minus the ELBO). If provided, the arguments to step() are piped to model() and guide().  
 - The method evaluate_loss() returns an estimate of the loss without taking a gradient step. Just like for step(), if provided, arguments to evaluate_loss() are piped to model() and guide().  
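When the loss is the ELBO, the optional num_particles argument sets the number of samples used to estimate the loss (in the case of evaluate_loss) and the loss and gradient (in the case of step). In current Pyro releases it is passed to the loss constructor, e.g. Trace_ELBO(num_particles=10), rather than to step() or evaluate_loss(), whose arguments are all forwarded to the model and guide.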

## Optimizers
In Pyro, the model and guide are allowed to be arbitrary stochastic functions provided that:  
 - guide doesn’t contain pyro.sample statements with the obs argument  
 - model and guide have the same call signature  

This presents some challenges because it means that different executions of model() and guide() may have quite different behavior, with e.g. certain latent random variables and parameters only appearing some of the time. Indeed parameters may be created dynamically during the course of inference. In other words the space we’re doing optimization over, which is parameterized by θ and ϕ, can grow and change dynamically.  
In order to support this behavior, Pyro needs to dynamically generate an optimizer for each parameter the first time it appears during learning. Luckily, PyTorch has a lightweight optimization library (see torch.optim) that can easily be repurposed for the dynamic case.  
All of this is controlled by the optim.PyroOptim class, which is basically a thin wrapper around PyTorch optimizers. PyroOptim takes two arguments: a constructor for PyTorch optimizers optim_constructor and a specification of the optimizer arguments optim_args. At a high level, in the course of optimization, whenever a new parameter is seen optim_constructor is used to instantiate a new optimizer of the given type with arguments given by optim_args.  

Most users will probably not interact with PyroOptim directly and will instead interact with the aliases defined in optim/__init__.py. There are two ways to specify the optimizer arguments. In the simpler case, optim_args is a fixed dictionary that specifies the arguments used to instantiate PyTorch optimizers for all the parameters.  
The second way to specify the arguments allows for a finer level of control. Here the user must specify a callable that will be invoked by Pyro upon creation of an optimizer for a newly seen parameter. This callable takes the following arguments:  
 - module_name: the Pyro name of the module containing the parameter, if any
 - param_name: the Pyro name of the parameter  

This gives the user the ability to, for example, customize learning rates for different parameters. For an example where this level of control is useful, see the discussion of baselines.
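
As a minimal sketch of such a callable, assuming the two-argument form described above: the function name `per_param_args`, the learning rates, and the empty string used for module_name in the standalone calls are all hypothetical. The parameter name alpha_q matches the guide used in the example below.

```python
def per_param_args(module_name, param_name):
    """Return the optimizer arguments for one newly seen parameter."""
    if param_name == "alpha_q":
        # Illustrative choice: a larger learning rate for one parameter.
        return {"lr": 0.01}
    # Default arguments for every other parameter.
    return {"lr": 0.001}

# The callable is plain Python, so it can be checked without Pyro.
print(per_param_args("", "alpha_q"))
print(per_param_args("", "beta_q"))

# Assumed usage, following the description above (not executed here):
# optimizer = pyro.optim.Adam(per_param_args)
```

Because the callable returns an ordinary dictionary, it can carry any keyword arguments the underlying PyTorch optimizer accepts, not only the learning rate.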

## Example: Determining Coin Fairness
As a simple example, suppose you are given a coin and want to determine the probability that it lands on heads. Your prior belief about the coin's fairness is Beta(10, 10). You then observe some data and want to update your belief about the fairness accordingly. First, we generate the data:

In [2]:
# create some data with 6 observed heads and 4 observed tails
data = []
for _ in range(6):
    data.append(torch.tensor(1.0))
for _ in range(4):
    data.append(torch.tensor(0.0))

The next step is to model the coin flips based on our prior belief and make observations on the data. We will also define our guide now. Notice the following things about the guide and the model:  
 - We’ve taken care that the names of the random variables line up exactly between the model and guide.
 - model(data) and guide(data) take the same arguments.
 - The variational parameters are torch.tensors. The requires_grad flag is automatically set to True by pyro.param.
 - We use constraint=constraints.positive to ensure that alpha_q and beta_q remain positive during optimization.
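
One way to picture the positivity constraint is as optimization over an unconstrained real number that is mapped to a positive value through an exponential, so the two are related by a log. The sketch below is plain Python, not Pyro's internals; the helper name `constrained` and the step sizes are made up for illustration.

```python
import math

# Unconstrained value that an optimizer would actually update.
# Its exponential is the initial constrained value 15.0.
unconstrained = math.log(15.0)

def constrained(u):
    """Map any real number to a strictly positive one."""
    return math.exp(u)

# Even a very large negative update (a made-up step, not a real gradient)
# leaves the constrained value strictly positive.
for step in (0.0, -1.0, -10.0, -100.0):
    value = constrained(unconstrained + step)
    print(f"u = {unconstrained + step:9.3f} -> alpha_q = {value:.3e}")
```

Since the optimizer never touches the constrained value directly, no clipping or projection step is needed to keep the Beta parameters valid.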

In [3]:
# clear the param store in case we're in a REPL
pyro.clear_param_store()

def model(data):
    # define the hyperparameters that control the beta prior
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data
    for i in range(len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

def guide(data):
    # register the two variational parameters with Pyro
    # - both parameters will have initial value 15.0.
    # - because we invoke constraints.positive, the optimizer
    # will take gradients on the unconstrained parameters
    # (which are related to the constrained parameters by a log)
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0),
                         constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0),
                        constraint=constraints.positive)
    # sample latent_fairness from the distribution Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

Now we can perform the inference. Note that in the step() method we pass in the data, which then get passed to the model and guide.

In [4]:
# setup the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 1500
# do gradient steps
for step in tqdm(range(n_steps)):
    svi.step(data)

# grab the learned variational parameters
alpha_q = pyro.param("alpha_q").item()
beta_q = pyro.param("beta_q").item()

# here we use some facts about the beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = alpha_q / (alpha_q + beta_q)
# compute inferred standard deviation
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * math.sqrt(factor)

print("\nbased on the data and our prior belief, the fairness " +
      "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))

based on the data and our prior belief, the fairness of the coin is 0.535 +- 0.090

This estimate should be compared to the exact posterior mean, which in this case is 16/30 ≈ 0.533. Note that the final estimate of the fairness of the coin lies between the fairness preferred by the prior (namely 0.50) and the fairness suggested by the raw empirical frequencies (6/10 = 0.60).
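
The exact posterior is available in closed form here because the Beta prior is conjugate to the Bernoulli likelihood: observing 6 heads and 4 tails turns Beta(10, 10) into Beta(16, 14). The following sketch computes the exact posterior mean and standard deviation for comparison with the SVI estimate; the variable names are illustrative.

```python
import math

# Prior hyperparameters and observed counts from the example above.
alpha0, beta0 = 10.0, 10.0
heads, tails = 6, 4

# Conjugate update: add the counts to the prior pseudo-counts.
alpha_post = alpha0 + heads
beta_post = beta0 + tails

# Mean and standard deviation of a Beta(alpha_post, beta_post) distribution.
total = alpha_post + beta_post
exact_mean = alpha_post / total
exact_var = (alpha_post * beta_post) / (total ** 2 * (total + 1.0))
exact_std = math.sqrt(exact_var)

print(f"exact posterior: Beta({alpha_post:.0f}, {beta_post:.0f})")
print(f"mean {exact_mean:.3f} +- {exact_std:.3f}")
```

The exact values, about 0.533 for the mean and 0.090 for the standard deviation, are close to the SVI result reported above.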