In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import norm
import seaborn as sns

import torch
import pyro
import pyro.distributions as dist
from torch import nn
from pyro.nn import PyroModule
from torch.distributions import constraints
from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
from pyro.infer import Predictive
import pyro.optim as optim
from IPython.display import Image

# Pyro vocabulary
---

## model

- `model`: in Pyro, any Python callable can serve as a model.
It is a "composition of primitive stochastic functions and deterministic computations"
and represents a generative distribution $P(X | \theta)$.
A stochastic function can be any Python object with a `__call__()` method, like a function, a method, or a PyTorch `nn.Module`.


``` python
def weather(temp_yesterday):
    temp_today = pyro.sample("temp_today", dist.Normal(temp_yesterday, 1.0))
    measurement = pyro.sample("measurement", dist.Normal(temp_today, 0.75))
    return measurement
```

NB: `temp_yesterday` is the input, `temp_today` is a latent variable, and `measurement` is the output.
Calling `weather()` draws a measurement $M$ from $P(M | T_y)$.

<!-- or $P(M | T_t) P(T_t | T_y)$ -->

---
## sample
- `sample`: draw a named random value from a primitive stochastic function; the first argument is the site name, the second is the function (or distribution) to draw from

In [2]:
pyro.sample('now', lambda : [1]*5)

[1, 1, 1, 1, 1]

In [3]:
def hot(temperature):
    if temperature > 10:
        return lambda : True
    else:
        return lambda : False
        
pyro.sample('now', hot(12)), pyro.sample('now', hot(5))

(True, False)

- with `pyro.distributions`, the parameter shapes are broadcast against each other:

In [4]:
pyro.sample("normal_sample", dist.Normal(torch.ones(2,3), 1))

tensor([[1.0470, 1.0677, 0.0269],
        [1.3179, 0.0765, 1.0546]])

***
## Shapes
Three (!) different shapes are at play:

- sample_shape
- batch_shape 
- event_shape

A PyTorch tensor has only one shape attribute, `.shape`:

In [5]:
ones = torch.ones(2,3)
ones.shape

torch.Size([2, 3])

A Pyro distribution has two shapes, `batch_shape` and `event_shape`:

(by the way, the parameter shapes are broadcast when the distribution is constructed)

In [6]:
d_norm = dist.Normal(torch.ones(2,3), 1)
d_norm.batch_shape, d_norm.event_shape

(torch.Size([2, 3]), torch.Size([]))

In [7]:
d_mnorm = dist.MultivariateNormal(torch.ones(3), torch.eye(3))
d_mnorm.batch_shape, d_mnorm.event_shape

(torch.Size([]), torch.Size([3]))

Indices over `.batch_shape` denote conditionally independent random variables, 
whereas indices over `.event_shape` denote dependent random variables (i.e. one draw from a distribution). Because the dependent random variables define probability together, the `.log_prob()` method only produces a single number for each event of shape `.event_shape`. Thus the total shape of `.log_prob()` is `.batch_shape`:

In [8]:
d_norm.log_prob(torch.rand(2, 3)).shape, d_norm.log_prob(torch.rand(1)).shape

(torch.Size([2, 3]), torch.Size([2, 3]))

In [9]:
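# Left commented out: a value of shape (2,) does not broadcast against batch_shape (2, 3), so this raises an error.
# d_norm.log_prob(torch.rand(2)).shape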

In [10]:
d_mnorm.log_prob(torch.rand(1)).shape, d_mnorm.log_prob(torch.rand(3)).shape

(torch.Size([]), torch.Size([]))

Putting the three together, a draw from `.sample(sample_shape)` has shape `sample_shape + batch_shape + event_shape`:

In [11]:
sample_shape = (8, 7)
current_dist = d_mnorm
data = current_dist.sample(sample_shape)
assert data.shape == sample_shape + current_dist.batch_shape + current_dist.event_shape
sample_shape, current_dist.batch_shape, current_dist.event_shape, data.shape

((8, 7), torch.Size([]), torch.Size([3]), torch.Size([8, 7, 3]))

In [12]:
sample_shape = (8, 7)
current_dist = d_norm
data = current_dist.sample(sample_shape)
assert data.shape == sample_shape + current_dist.batch_shape + current_dist.event_shape
sample_shape,  current_dist.batch_shape, current_dist.event_shape, data.shape

((8, 7), torch.Size([2, 3]), torch.Size([]), torch.Size([8, 7, 2, 3]))

---
## Condition

- `condition`: we can condition a model on fixed values for some of its sample sites


``` python
def weather(temp_yesterday):
    temp_today = pyro.sample("temp_today", dist.Normal(temp_yesterday, 1.0))
    measurement = pyro.sample("measurement", dist.Normal(temp_today, 0.75))
    return measurement
```

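Calling `weather()` samples from $P(M | T_y)$, but the model in fact defines the joint $P(M, T_t | T_y) = P(M | T_t) P(T_t | T_y)$,

so we can fix the latent variable and sample from $P(M | T_y, T_t = t)$ instead: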

In [13]:
def weather(temp_yesterday):
    temp_today = pyro.sample("temp_today", dist.Normal(temp_yesterday, 1.0))
    measurement = pyro.sample("measurement", dist.Normal(temp_today, 1.0))
    return measurement

conditioned_weather = pyro.condition(weather, data={"temp_today": 10.})

weather(10), conditioned_weather(10.)

(tensor(6.9269), tensor(9.8178))

- conditioning can equivalently be written with the `obs` argument to `sample`:

``` python
def weather(temp_yesterday):
    temp_today = pyro.sample("temp_today", dist.Normal(temp_yesterday, 1.0))
    measurement = pyro.sample("measurement", dist.Normal(temp_today, 1.0), obs=torch.tensor(5.0))
    return measurement
```

Recap of the three shapes and the kind of random variables each one indexes:

| | iid | independent | dependent |
| --- | --- | --- | --- |
| *shape* | `sample_shape` | `batch_shape` | `event_shape` |


## Inference

Given $P(X| \theta_{obs}, \theta_{unobs})$ infer the posterior $P(\theta_{obs}, \theta_{unobs} | X)$.

In variational inference, the posterior is approximated by an optimal $\varphi (\theta_{obs}, \theta_{unobs})$ from a chosen class of distributions, such that

a certain loss function is minimized, for example the negative ELBO, $-\mathbb{E}_\varphi [\log P(X, \theta_{obs}, \theta_{unobs}) - \log\varphi (\theta_{obs}, \theta_{unobs})]$.
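To see why an objective of this form is reasonable, write $\theta = (\theta_{obs}, \theta_{unobs})$ as shorthand (notation introduced here, not in the text above). The log evidence then decomposes as:

$$
\log P(X) = \mathbb{E}_\varphi [\log P(X, \theta) - \log\varphi (\theta)] + \mathrm{KL}\left(\varphi (\theta) \,\|\, P(\theta | X)\right)
$$

Since $\log P(X)$ does not depend on $\varphi$ and the KL term is non-negative, maximizing the expectation (the ELBO), or equivalently minimizing its negative, pushes $\varphi$ towards the posterior $P(\theta | X)$.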

## guide

- `guide`: a parametrized approximation $\varphi (\theta_{obs}, \theta_{unobs})$ of the posterior, whose parameters are optimized during inference

Quoting the Pyro documentation:

Inference algorithms in Pyro, such as pyro.infer.SVI, allow us to use arbitrary stochastic functions, which we will call guide functions or guides, as approximate posterior distributions. Guide functions must satisfy these two criteria to be valid approximations for a particular model: 

- all unobserved (i.e., not conditioned) sample statements that appear in the model appear in the guide. 
- the guide has the same input signature as the model (i.e., takes the same arguments).

Guide functions can serve as programmable, data-dependent proposal distributions for importance sampling, rejection sampling, sequential Monte Carlo, MCMC, and independent Metropolis-Hastings, and as variational distributions or inference networks for stochastic variational inference. Currently, importance sampling, MCMC, and stochastic variational inference are implemented in Pyro, and we plan to add other algorithms in the future.

Although the precise meaning of the guide is different across different inference algorithms, the guide function should generally be chosen so that, in principle, it is flexible enough to closely approximate the distribution over all unobserved sample statements in the model.