# SVI Part II: Conditional Independence, Subsampling, and Amortization

*This tutorial is adapted from [https://pyro.ai/examples/svi_part_ii.html](https://pyro.ai/examples/svi_part_ii.html).*

## The Goal: Scaling SVI to Large Datasets

For a model with $N$ observations, running the **model** and **guide** and constructing the ELBO involves evaluating log pdfs whose cost grows with $N$, since every observation contributes a term. This becomes a problem when we want to scale to large datasets. Luckily, the ELBO objective naturally supports subsampling, provided that our model and guide have some conditional independence structure that we can take advantage of. For example, when the observations are conditionally independent given the latents, the log likelihood term in the ELBO can be approximated with

$$\sum_{i=1}^N \log p({\bf x}_i | {\bf z}) \approx \frac{N}{M} \sum_{i\in\mathcal{I}_M} \log p({\bf x}_i | {\bf z})$$

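where $\mathcal{I}_M$ is a mini-batch of $M < N$ indices drawn at random from $\{1, \dots, N\}$. The factor $N/M$ rescales the mini-batch sum so that it matches the full sum in expectation. So how do we do this in Pyro?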
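As a quick numerical check of the formula, here is a minimal sketch in plain PyTorch (no Pyro yet). It assumes a toy set of $N = 12$ Bernoulli observations, a fixed latent value $z = 0.6$, and mini-batches of size $M = 4$. All of these choices are for illustration only. Each rescaled mini-batch sum is a noisy estimate. When the mini-batches are formed by splitting a random permutation of the indices, the average of their estimates recovers the full sum exactly, which is consistent with the estimator being unbiased.

```python
import torch

torch.manual_seed(0)

# Toy setup (illustrative): N Bernoulli observations and a fixed latent value z
N, M = 12, 4
z = torch.tensor(0.6)
x = torch.bernoulli(torch.full((N,), 0.6))

# Per-datapoint log-likelihood terms log p(x_i | z)
log_lik = torch.distributions.Bernoulli(probs=z).log_prob(x)
full_sum = log_lik.sum()

def minibatch_estimate(indices):
    # Scale the mini-batch sum by N / M, as in the approximation above
    return (N / M) * log_lik[indices].sum()

# Split a random permutation of 0..N-1 into N / M disjoint mini-batches
perm = torch.randperm(N)
batches = perm.split(M)
estimates = torch.stack([minibatch_estimate(b) for b in batches])

print("full sum:", full_sum.item())
print("mini-batch estimates:", estimates.tolist())
print("mean of estimates:", estimates.mean().item())
```

In practice, SVI uses a single random mini-batch per step. The noise in that one estimate is the price paid for touching only $M$ of the $N$ datapoints.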

## Marking Conditional Independence in Pyro

To do this in Pyro, we first need to make sure that the **model** and **guide** are written in such a way that Pyro can leverage the relevant conditional independencies. Let's see how this is done. Pyro provides two language primitives for marking conditional independencies: **plate** and **markov**. Let's start with the simpler of the two, **plate**.

### Sequential `plate`

Let's return to the example we used in the previous tutorial. For convenience, let's replicate the main logic of the model here:

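```python
import torch
import pyro
import pyro.distributions as dist

# hyperparameters of the Beta prior (illustrative values)
alpha0 = torch.tensor(10.0)
beta0 = torch.tensor(10.0)

def model(data):
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))

    # loop over the observed data using pyro.sample with the obs keyword argument
    for i in range(len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
```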

For this model, the observations are conditionally independent given the latent random variable **latent_fairness**. To mark this explicitly in Pyro, we basically just need to replace the Python builtin **range** with the Pyro construct **plate**:

```python
def model(data):
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))

    # loop over the observed data [WE ONLY CHANGE THE NEXT LINE]
    for i in pyro.plate("data_loop", len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
```

We see that **pyro.plate** is very similar to **range**, with one main difference: each invocation of **plate** requires the user to provide a unique name. The second argument is an integer, just like for **range**.

So far so good. Pyro can now leverage the conditional independence of the observations given the latent random variable. But how does this actually work? Basically, **pyro.plate** is implemented as a context manager. On every execution of the body of the for loop, we enter a new (conditional) independence context, which is then exited at the end of the loop body. Let's be very explicit about this:

- because each observed **pyro.sample** statement occurs within a different execution of the body of the for loop, Pyro marks each observation as independent
- this independence is properly a conditional independence given **latent_fairness** because **latent_fairness** is sampled outside of the context of **data_loop**.


Before moving on, let's mention some gotchas to be avoided when using sequential plate. Consider the following variant of the above code snippet:

```python
# WARNING do not do this!
# (assumes f and data are defined as in the model above)
my_reified_list = list(pyro.plate("data_loop", len(data)))

for i in my_reified_list:
    pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
```

This will not achieve the desired behavior, since **list()** enters and exits the **data_loop** context completely before a single **pyro.sample** statement is called. Similarly, we need to take care not to leak mutable computations across the boundary of the context manager, as this may lead to subtle bugs. For example, **pyro.plate** is not appropriate for temporal models where each iteration of a loop depends on the previous iteration; in this case a plain **range** or **pyro.markov** should be used instead.

### Vectorized `plate`

Conceptually, vectorized plate is the same as sequential plate, except that it is a vectorized operation (as **torch.arange** is to **range**). As such, it can enable large speed-ups compared to the explicit for loop of sequential plate. Let's see how this looks for our running example. First we need the data to be in the form of a tensor:

```python
import torch
import pyro

data = torch.zeros(10)
data[0:6] = torch.ones(6)  # 6 heads and 4 tails
```

```python
# inside the model, after sampling f from the Beta prior
with pyro.plate('observe_data'):
    pyro.sample('obs', pyro.distributions.Bernoulli(f), obs=data)
```

Let's compare this to the analogous sequential plate usage point by point:

- Both patterns require the user to specify a unique name.
- This code snippet introduces only a single (observed) random variable (namely **obs**), since the entire tensor is considered at once.
- Since there is no need for an iterator in this case, there is no need to specify the length of the tensor(s) involved in the plate context.

Note that the gotchas mentioned in the case of **sequential plate** also apply to **vectorized plate**.