# SVI Part I: Stochastic Variational Inference in Pyro

*This tutorial is adapted from [https://pyro.ai/examples/svi_part_i.html](https://pyro.ai/examples/svi_part_i.html).*

Pyro has been designed with particular attention paid to supporting stochastic variational inference as a general-purpose inference algorithm. In this tutorial, we will see how to do variational inference in Pyro. In the two subsequent notebooks, we will apply what we learn here to develop two widely used probabilistic models:

- **Variational autoencoders**.
- **Hidden Markov models**.

## Setup

We assume that the model has already been defined in Pyro. As a quick reminder, the model is given as a stochastic function **model(\*args, \*\*kwargs)**, which in the general case takes arguments.

The different pieces of model() are encoded via the mapping:

1. observations ⟺ **pyro.sample** with the obs argument

2. latent random variables ⟺ **pyro.sample**

3. parameters ⟺ **pyro.param**

Now let's establish some notation. The model has *observations* $\boldsymbol{x}$ and *latent random variables* $\boldsymbol{z}$ as well as parameters $\theta$. It has the joint probability density:

$p_{\theta}(\boldsymbol{x}, \boldsymbol{z}) = p_{\theta}(\boldsymbol{x}\,|\,\boldsymbol{z})\cdot p_{\theta}(\boldsymbol{z})$

We assume that the various probability distributions that make up $p_{\theta}(\boldsymbol{x}, \boldsymbol{z})$ have the following properties:

1. We can sample from each distribution.
2. We can compute the pointwise log pdf of each distribution.
3. Each distribution is differentiable w.r.t. the parameters $\theta$.

## Model Learning

In this context our criterion for learning a good model will be **maximizing the log evidence**, i.e. we want to find the value of $\theta$ given by

$\theta_{max} = \arg \max_{\theta} \log p_{\theta}(\boldsymbol{x}),$

where the log evidence $\log p_{\theta}(\boldsymbol{x})$ is given by

$\log p_{\theta}(\boldsymbol{x}) = \log \int \, p_{\theta}(\boldsymbol{x},\boldsymbol{z})\, d\boldsymbol{z}.$

In the general case this is a doubly difficult problem. This is because (even for a fixed $\theta$) the integral over the latent random variables $\boldsymbol{z}$ is often intractable. Furthermore, even if we know how to calculate the log evidence for all values of $\theta$, maximizing the log evidence as a function of $\theta$ will in general be a difficult non-convex optimization problem.
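To make the integral concrete, here is a small sketch for a case where it is not intractable. It assumes a toy conjugate model that is not part of the original tutorial: $z \sim \mathcal{N}(0, 1)$ and $x \,|\, z \sim \mathcal{N}(z, 1)$, with no free parameters $\theta$. In this special case the evidence has the closed form $x \sim \mathcal{N}(0, \sqrt{2})$, so a numerical integration over a one-dimensional grid can be checked against it. The helper names and the observed value are hypothetical.

```python
import math


def normal_pdf(value, mean, std):
    """Density of a univariate Normal distribution."""
    return math.exp(-0.5 * ((value - mean) / std) ** 2) / (
        std * math.sqrt(2.0 * math.pi)
    )


def joint(x, z):
    # p(x, z) = p(x | z) * p(z) for the assumed toy model
    return normal_pdf(x, z, 1.0) * normal_pdf(z, 0.0, 1.0)


def log_evidence_numeric(x, lo=-10.0, hi=10.0, n=2001):
    # Trapezoidal rule over a grid of z values.
    step = (hi - lo) / (n - 1)
    values = [joint(x, lo + i * step) for i in range(n)]
    integral = step * (sum(values) - 0.5 * (values[0] + values[-1]))
    return math.log(integral)


def log_evidence_exact(x):
    # Marginally, x ~ Normal(0, sqrt(2)) for this toy model.
    return math.log(normal_pdf(x, 0.0, math.sqrt(2.0)))


x_obs = 0.5  # made-up observation
numeric = log_evidence_numeric(x_obs)
exact = log_evidence_exact(x_obs)
print(numeric, exact)
```

Grid integration like this only works because $\boldsymbol{z}$ is one-dimensional here. The number of grid points grows exponentially with the dimension of $\boldsymbol{z}$, which is one way to see why the integral is usually intractable.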

In addition to finding $\theta_{max}$, we would like to calculate the posterior over the latent variables $\boldsymbol{z}$:

$p_{\theta_{max}}(\boldsymbol{z}\,|\,\boldsymbol{x}) = \frac{p_{\theta_{max}}(\boldsymbol{x},\,\boldsymbol{z})}{\int p_{\theta_{max}}(\boldsymbol{x},\,\boldsymbol{z}) d\boldsymbol{z}}$

Note that the denominator of this expression is the (usually intractable) evidence. Variational inference offers a scheme for finding $\theta_{max}$ and computing an approximation to the posterior $p_{\theta_{max}}(\boldsymbol{z}\,|\,\boldsymbol{x})$. Let's see how that works.
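As a sanity check on the posterior formula, the sketch below evaluates it on a grid for an assumed toy model that does not appear in the original tutorial: $z \sim \mathcal{N}(0, 1)$ and $x \,|\, z \sim \mathcal{N}(z, 1)$. For this conjugate pair the exact posterior is known to be $\mathcal{N}(x/2, \sqrt{1/2})$, so the grid-based mean and variance can be compared against it. All names and the observed value are hypothetical.

```python
import math


def normal_pdf(value, mean, std):
    """Density of a univariate Normal distribution."""
    return math.exp(-0.5 * ((value - mean) / std) ** 2) / (
        std * math.sqrt(2.0 * math.pi)
    )


x_obs = 0.5  # made-up observation
step = 0.01
grid = [-10.0 + i * step for i in range(2001)]

# Numerator of the posterior: the joint density p(x, z) on the grid.
joint = [normal_pdf(x_obs, z, 1.0) * normal_pdf(z, 0.0, 1.0) for z in grid]

# Denominator: the evidence, approximated by a Riemann sum over z.
evidence = step * sum(joint)

# Posterior density on the grid, then its first two moments.
posterior = [j / evidence for j in joint]
post_mean = step * sum(z * p for z, p in zip(grid, posterior))
post_var = step * sum((z - post_mean) ** 2 * p for z, p in zip(grid, posterior))

print(post_mean, post_var)
```

The only step that does not scale to realistic models is the computation of `evidence`, which is exactly the quantity variational inference avoids computing directly.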

## Guide

The basic idea is that we introduce a parameterized distribution $q_{\varphi}(\boldsymbol{z})$, where $\varphi$ are known as the **variational parameters**. This distribution is called the **variational distribution** in much of the literature; in the context of Pyro it is called **the guide**. The guide will serve as an approximation to the posterior distribution $p_{\theta_{max}}(\boldsymbol{z}\,|\,\boldsymbol{x})$.

Just like the model, the guide is encoded as a stochastic function **guide()** that contains **pyro.sample** and **pyro.param** statements. It does not contain observed data, since the guide needs to be a properly normalized distribution. Note that Pyro enforces that **model()** and **guide()** have the same call signature, i.e. both callables should take the same arguments.

Since the guide is an approximation to the posterior $p_{\theta_{max}}(\boldsymbol{z}\,|\,\boldsymbol{x})$, the guide needs to provide a valid joint probability density over all the latent random variables in the model. Recall that when random variables are specified in Pyro with the primitive statement **pyro.sample()** the first argument denotes the name of the random variable. These names will be used to align the random variables in the model and guide. 

To be very explicit, if the model contains a random variable $z_1$

```python
def model():
    pyro.sample("z_1", ...)
```

then the guide needs to have a matching sample statement

```python
def guide():
    pyro.sample("z_1", ...)
```

**The distributions used in the two cases can be different, but the names must line up one-to-one.**

Once we've specified a guide (we give some explicit examples below), we're ready to proceed to inference. Learning will be set up as an optimization problem where each iteration of training takes a step in $\theta$-$\varphi$ space that moves the guide closer to the exact posterior. To do this we need to define an appropriate objective function.

## Evidence Lower Bound (ELBO) 