# How Pyro Really Works

## or How to Write Yourself a Probabilistic Programming Framework

Ever wondered how Pyro really works? How does it do variational inference, and what is variational inference in the first place? It is conceptually simpler than one might think, because it follows from combining a handful of rather simple ideas. In this tutorial we will look at each of them in turn.

## 0. Joint, Marginal, and Conditional Distributions

### Basics

Probability distributions may be defined on several random variables (RVs), e.g. $a$, $b$, and $c$. The distribution over all variables $p(a,b,c)$ is called the *joint distribution*, which is normalized such that the sum over all possible RV assignments is 1. (In the continuous case, we integrate instead of summing.)

The *marginal distribution* of a subset of the RVs is obtained from the joint distribution by summing (or integrating) over all remaining variables. For example,
$$ p(a) = \sum_b \sum_c p(a,b,c) $$
and
$$ p(a,b) = \sum_c p(a,b,c) $$
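
The two sums above can be sketched in a few lines of plain Python. This is only an illustration: the joint table and its weights are made up, and the helper `marginal` is a hypothetical name, not part of Pyro or PyTorch.

```python
from itertools import product

# Hypothetical joint distribution p(a, b, c) over three binary RVs,
# stored as a table mapping each assignment (a, b, c) to its probability.
weights = [1, 2, 3, 4, 5, 6, 7, 8]  # arbitrary unnormalized weights
total = sum(weights)
joint = {abc: w / total for abc, w in zip(product([0, 1], repeat=3), weights)}

def marginal(joint, keep):
    """Sum the joint over every variable whose position is not in `keep`."""
    result = {}
    for assignment, prob in joint.items():
        key = tuple(assignment[i] for i in keep)
        result[key] = result.get(key, 0.0) + prob
    return result

p_a = marginal(joint, keep=[0])      # p(a)    = sum_b sum_c p(a, b, c)
p_ab = marginal(joint, keep=[0, 1])  # p(a, b) = sum_c p(a, b, c)
```

Both `p_a` and `p_ab` sum to 1 again, because marginalizing only regroups the probability mass of the joint.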

A *conditional distribution* is the distribution of some of the RVs given that other RVs are known to have a certain value, e.g. $p(a,b \mid c=1)$.
If the value of the known RV is arbitrary, we write $p(a,b \mid c)$.
If we are only interested in the distribution of some of the unknown RVs, we can again marginalize out the remaining unknown RVs, e.g.
$$ p(a \mid c) = \sum_b p(a, b \mid c) $$

We can obtain the conditional distribution from the joint distribution by fixing the value of the known RVs ($p(a,b,c=1)$).
Since we now only consider a subset of the RV assignments (those in which $c=1$), we have to renormalize by the total probability of that subset, $p(c=1)$:
$$ p(a,b \mid c=1) = \dfrac{p(a,b,c=1)}{p(c=1)}$$
or more generally
$$ p(a,b \mid c) = \dfrac{p(a,b,c)}{p(c)} $$
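
As a minimal sketch of "fix the known value, then renormalize", the snippet below conditions a made-up joint table on $c=1$. The table, its weights, and the helper `condition_on_c` are assumptions chosen for illustration only.

```python
from itertools import product

# Hypothetical joint p(a, b, c) over three binary RVs (weights are arbitrary).
weights = [1, 2, 3, 4, 5, 6, 7, 8]
total = sum(weights)
joint = {abc: w / total for abc, w in zip(product([0, 1], repeat=3), weights)}

def condition_on_c(joint, c_value):
    """Return p(a, b | c = c_value) as a table keyed by (a, b)."""
    # Fix c: keep only the assignments consistent with the known value.
    selected = {(a, b): p for (a, b, c), p in joint.items() if c == c_value}
    # Renormalize by p(c = c_value), the total mass of the kept assignments.
    p_c = sum(selected.values())
    return {ab: p / p_c for ab, p in selected.items()}

p_ab_given_c1 = condition_on_c(joint, c_value=1)

# Marginalize out b to obtain p(a | c = 1).
p_a_given_c1 = {}
for (a, b), p in p_ab_given_c1.items():
    p_a_given_c1[a] = p_a_given_c1.get(a, 0.0) + p
```

Without the division by `p_c`, the selected entries would sum to $p(c=1)$ rather than to 1, so they would not form a distribution over $a$ and $b$.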

### Factorizing a Joint Distribution

From the above section it follows that every joint distribution can be factorized into simpler conditional and marginal distributions.
For example, we can turn the previous equation around and obtain
$$ p(a,b,c) = p(a,b \mid c) p(c) $$
Since $p(a,b \mid c)$ can be understood as another joint distribution (of $a$ and $b$, given $c$), it can be factorized again (given $c$!), e.g.:
$$ p(a,b,c) = p(a,b \mid c) p(c) = p(a \mid b,c) p(b \mid c) p(c). $$
Note that the order in which we factor out variables doesn't matter, so the following is equivalent:
$$ p(a,b,c) = p(b, c \mid a) p(a) = p(b \mid c,a) p(c \mid a) p(a). $$

We can understand any factorization as an ordering of the RVs when sampling from the joint distribution.
Instead of sampling all of them together, we sample each RV from its conditional distribution, starting with the one that is sampled from its unconditional marginal distribution.
For example, in the first factorization above ($p(a \mid b,c) p(b \mid c) p(c)$), we first sample $c$ from $p(c)$, which doesn't depend on any of the other RVs.
Then, knowing the value of $c$, we can sample $b$ from $p(b \mid c)$, and finally (knowing $b$ and $c$) we can sample $a$ from $p(a \mid b,c)$.
Thus we have turned the joint distribution $p(a,b,c)$ into a *generative process* that produces each RV in turn.
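
The generative process described above can be sketched with the standard library alone. All probabilities below are invented for illustration, and `bernoulli` and `sample_joint` are hypothetical helper names; the point is only the order of the three sampling steps.

```python
import random

# Hypothetical conditional probability tables for three binary RVs.
P_C1 = 0.3                                  # p(c=1)
P_B1_GIVEN_C = {0: 0.2, 1: 0.9}             # p(b=1 | c)
P_A1_GIVEN_BC = {(0, 0): 0.1, (0, 1): 0.5,  # p(a=1 | b, c), keyed by (b, c)
                 (1, 0): 0.7, (1, 1): 0.95}

def bernoulli(p, rng):
    """Draw 1 with probability p, else 0."""
    return 1 if rng.random() < p else 0

def sample_joint(rng):
    """Draw (a, b, c) from p(a | b, c) p(b | c) p(c), one RV at a time."""
    c = bernoulli(P_C1, rng)                   # c depends on nothing
    b = bernoulli(P_B1_GIVEN_C[c], rng)        # b depends on c
    a = bernoulli(P_A1_GIVEN_BC[(b, c)], rng)  # a depends on b and c
    return a, b, c

rng = random.Random(0)  # fixed seed so repeated runs give the same samples
samples = [sample_joint(rng) for _ in range(20000)]

# Empirical frequencies should approach the marginals implied by the tables.
freq_c1 = sum(c for _, _, c in samples) / len(samples)
freq_b1 = sum(b for _, b, _ in samples) / len(samples)
```

With these tables, $p(b=1) = 0.7 \cdot 0.2 + 0.3 \cdot 0.9 = 0.41$, so `freq_b1` should land close to that value even though $b$ is never sampled from its marginal directly.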

## 1. Bayesian Inference



In [1]:
import torch

In [5]:
x = torch.tensor(1.)

In [19]:
def f(x):
    return 0.5 * x**2 + 1.

In [20]:
y = f(x)
y

tensor(1.5000)

In [25]:
x.requires_grad = True
y = f(x)
y.backward()

In [28]:
x.grad

tensor(1.)
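
As a quick sanity check, the gradient reported by autograd can be compared with the derivative of `f` worked out by hand from its definition above:

$$ f(x) = \tfrac{1}{2} x^2 + 1 \quad\Rightarrow\quad f'(x) = x, \qquad f'(1) = 1 $$

This agrees with the value stored in `x.grad`.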

In [37]:
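# gradient descent on f, starting from x = 1
lr = 0.1  # learning rate (step size)
x = torch.tensor(1., requires_grad=True)
for i in range(100):
    y = f(x)
    y.backward()  # adds dy/dx to x.grad
    print('gradient = ', x.grad)
    with torch.no_grad():  # update x without recording the step for autograd
        x -= lr * x.grad
    print('x = ', x.detach())
    x.grad.zero_()  # reset, because backward() accumulates gradients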

gradient =  tensor(1.)
x =  tensor(0.9000)
gradient =  tensor(0.9000)
x =  tensor(0.8100)
gradient =  tensor(0.8100)
x =  tensor(0.7290)
gradient =  tensor(0.7290)
x =  tensor(0.6561)
gradient =  tensor(0.6561)
x =  tensor(0.5905)
gradient =  tensor(0.5905)
x =  tensor(0.5314)
gradient =  tensor(0.5314)
x =  tensor(0.4783)
gradient =  tensor(0.4783)
x =  tensor(0.4305)
gradient =  tensor(0.4305)
x =  tensor(0.3874)
gradient =  tensor(0.3874)
x =  tensor(0.3487)
gradient =  tensor(0.3487)
x =  tensor(0.3138)
gradient =  tensor(0.3138)
x =  tensor(0.2824)
gradient =  tensor(0.2824)
x =  tensor(0.2542)
gradient =  tensor(0.2542)
x =  tensor(0.2288)
gradient =  tensor(0.2288)
x =  tensor(0.2059)
gradient =  tensor(0.2059)
x =  tensor(0.1853)
gradient =  tensor(0.1853)
x =  tensor(0.1668)
gradient =  tensor(0.1668)
x =  tensor(0.1501)
gradient =  tensor(0.1501)
x =  tensor(0.1351)
gradient =  tensor(0.1351)
x =  tensor(0.1216)
gradient =  tensor(0.1216)
x =  tensor(0.1094)
gradient =  tenso

