In [None]:
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
pyro.enable_validation(True)

import torch as T
import torch.optim as O
import torch.distributions.constraints as constraints
import numpy as np
from tqdm import tqdm

from bagoftools.plotting import stem_hist

import scipy.stats as stats

import random

import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
sns.set()

In [None]:
def f(n, i, j, do_print=True):
    """Build a Pyro program over a uniform categorical with ``2**n`` outcomes.

    The returned zero-argument callable samples site ``"x"`` from
    ``Categorical(ones(2**n))`` and observes site ``"obs"`` (a
    ``Bernoulli(1.0)``) at the float indicator of ``x == i or x == j``.

    Args:
        n: exponent; the categorical has ``2**n`` equally weighted categories.
        i: first value of ``x`` for which the observation is 1.
        j: second value of ``x`` for which the observation is 1.
        do_print: when true, print the sampled ``x`` on every execution.

    Returns:
        A callable taking no arguments that runs the program and returns ``x``.
    """

    def _f():
        num_outcomes = 2**n
        uniform_weights = T.ones(num_outcomes)
        x = pyro.sample("x", dist.Categorical(uniform_weights))

        # Indicator of the predicate, cast to float for the Bernoulli observation.
        hit = ((x == i) | (x == j)).float()
        pyro.sample("obs", dist.Bernoulli(1.0), obs=hit)

        if do_print:
            print('model x = {}'.format(x))
        return x

    return _f

In [None]:
# Start from an empty parameter store.
pyro.clear_param_store()

# Program over 2**6 categories whose observation is the indicator of x in {1, 10}.
prog = f(6, 1, 10, do_print=False)
# Importance sampling over `prog` using 10000 samples.
posterior = pyro.infer.Importance(prog, num_samples=10000)
# Empirical marginal of site 'x', built from the result of running the sampler.
marginal = pyro.infer.EmpiricalMarginal(posterior.run(), sites=['x'])

# Draw 1000 values from the marginal (as Python numbers) and plot them with stem_hist.
x = [marginal().item() for _ in range(1_000)]
stem_hist(x)

In [None]:
# The same program is used as both model and guide; do_print=True makes every
# execution print the value it sampled for x.
model = f(3, 0, 1, do_print=True)
guide = f(3, 0, 1, do_print=True)

# Evaluate the enumeration-based ELBO loss once, with the guide wrapped by
# config_enumerate(..., "parallel"). The trailing `;` suppresses the cell output.
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "parallel"));

In [None]:
# NOTE(review): looks like a debugging leftover that inspects an attribute of the
# ELBO object — confirm `enumerate_support` actually exists on TraceEnum_ELBO.
elbo.enumerate_support

---

In [None]:
def params():
    """Return the current contents of Pyro's parameter store as a ``dict``."""
    return dict(pyro.get_param_store().items())

def hist(xs, bins=64, xlim=None):
    """Draw a density-normalised cyan histogram of ``xs`` on a new 14x4 figure.

    Args:
        xs: values to histogram (passed straight to ``plt.hist``).
        bins: bin specification forwarded to ``plt.hist``.
        xlim: optional x-axis limits; the axis is left untouched when ``None``.
    """
    plt.figure(figsize=(14, 4))
    plt.hist(xs, bins=bins, density=True, color='c')

    # Nothing more to do unless the caller asked for explicit x limits.
    if xlim is None:
        return
    plt.xlim(xlim)

## Basics

In [None]:
# Normal distribution with loc 0 and scale 1.
d = dist.Normal(0, 1)
# Density of `d` at x: exponentiate the log-probability and unwrap it to a Python number.
p = lambda x: d.log_prob(T.tensor(x)).exp().item()

# Density at the mean.
p(0)

## Conditioning

In [None]:
def scale(guess):
    """Noisy-scale model: sample a weight around ``guess``, then a measurement of it.

    Args:
        guess: mean of the ``Normal(guess, 1.0)`` distribution of site ``'weight'``.

    Returns:
        The value of site ``'measurement'``, drawn from ``Normal(weight, 0.75)``.
    """
    weight_dist = dist.Normal(guess, 1.0)
    weight = pyro.sample('weight', weight_dist)

    measurement_dist = dist.Normal(weight, 0.75)
    return pyro.sample('measurement', measurement_dist)

# Condition site 'measurement' on the observed value 9.5. The observation is
# given as a tensor: with distribution validation enabled (pyro.enable_validation
# in the imports cell), scoring a site against a plain Python float is rejected
# by `log_prob`, so a float here would break any inference run on this model.
conditioned_scale = pyro.condition(scale, data={'measurement': T.tensor(9.5)})

In [None]:
# equivalent to conditioned_scale above
def scale_obs(guess):
    """Noisy-scale model with site ``'measurement'`` fixed to 9.5 via ``obs=``.

    Args:
        guess: mean of the ``Normal(guess, 1.)`` distribution of site ``'weight'``.

    Returns:
        The value of site ``'measurement'``, which is the observed tensor 9.5.
    """
    weight = pyro.sample('weight', dist.Normal(guess, 1.))
    # Here we condition on measurement == 9.5. The observation is passed as a
    # tensor because, with validation enabled, `log_prob` rejects a plain float.
    return pyro.sample('measurement', dist.Normal(weight, 0.75), obs=T.tensor(9.5))

In [None]:
# Both calls return the observed value of `measurement`, for any input `guess`:
# one model is conditioned via pyro.condition, the other via the `obs=` argument.
conditioned_scale(guess=0), scale_obs(guess=0)

In [None]:
@pyro.condition(data={'z': 1.0})
def gauss():
    """Sample site ``'z'`` from ``Normal(0, 1)``; the decorator conditions it on 1.0."""
    return pyro.sample('z', dist.Normal(0, 1))

gauss()

## Models and Inference
**Q:** how are vars involved / inter-related in a model (since they have a name)?

**A:** they are stored in `ParamStore` and are used during optimization to model distribution parameters.

---

- `weather()` specifies a joint probability distribution over two named random variables: `cloudy` and `temp`
- it defines a probabilistic model that we can reason about
- e.g. if I observe a temperature of 70 degrees, how likely is it to be cloudy?

In [None]:
def weather():
    """Sample a ``(sky, temperature)`` pair from a two-variable weather model.

    Site ``'cloudy'`` is ``Bernoulli(0.3)``; it selects the mean and scale of the
    Normal used for site ``'temp'`` (55/10 when cloudy, 75/15 when sunny).

    Returns:
        A tuple ``(sky, temp)`` where ``sky`` is ``'cloudy'`` or ``'sunny'`` and
        ``temp`` is the sampled temperature as a Python float.
    """
    is_cloudy = pyro.sample('cloudy', dist.Bernoulli(0.3)).item() == 1.0
    if is_cloudy:
        sky, mean_temp, scale_temp = 'cloudy', 55.0, 10.0
    else:
        sky, mean_temp, scale_temp = 'sunny', 75.0, 15.0

    temp = pyro.sample('temp', dist.Normal(mean_temp, scale_temp))
    return sky, temp.item()

weather()

### Docs example

If `y` is observed to be 9.5,
then find `a,b` for `x ~ N(a, b)`

In [None]:
# Suppose we observe that the measurement of an object was 9.5 lbs.
# What would we have guessed if we had tried to guess its weight first?
# i.e. compute P(x | y = 9.5)

def model(loc):
    """Generative model: latent ``x ~ Normal(loc, 1.0)``, then ``y ~ Normal(x, 0.75)``.

    Args:
        loc: mean of the distribution of the latent site ``'x'``.

    Returns:
        The value of site ``'y'``.
    """
    latent = pyro.sample('x', dist.Normal(loc, 1.0))
    return pyro.sample('y', dist.Normal(latent, 0.75))

# Same model with site 'y' fixed to the observed value 9.5.
conditioned_model = pyro.condition(model, data={'y': T.tensor(9.5)})

def guide(loc):
    """Variational guide for site ``'x'``: a Normal with learnable parameters.

    Parameter ``'a'`` (the mean) is initialised to ``loc``; parameter ``'b'``
    (the scale) is initialised to 1.0 and constrained to be positive.

    Args:
        loc: initial value for parameter ``'a'``.

    Returns:
        The value sampled for site ``'x'``.
    """
    mean = pyro.param('a', T.tensor(loc))
    std = pyro.param('b', T.tensor(1.0), constraint=constraints.positive)
    return pyro.sample('x', dist.Normal(mean, std))

# Clear the parameter store before fitting, then set up SVI on the conditioned
# model with SGD and the Trace_ELBO loss.
pyro.clear_param_store()
svi = pyro.infer.SVI(model=conditioned_model,
                     guide=guide,
                     optim=pyro.optim.SGD({'lr': 0.001, 'momentum': 0.1}),
                     loss=pyro.infer.Trace_ELBO())

# Value passed as `loc` to both model and guide on every step.
loc_prior = 8.5
# Per-step history of the loss and of the guide parameters 'a' and 'b'.
losses, a, b  = [], [], []
num_steps = 5000
for t in tqdm(range(num_steps)):
    # One optimisation step; the value returned by svi.step is recorded as the loss.
    losses.append(svi.step(loc_prior))
    a.append(pyro.param('a').item())
    b.append(pyro.param('b').item())

# Final parameter values.
print('a = ', pyro.param('a').item())
print('b = ', pyro.param('b').item())

# Three separate figures: the recorded losses, then the history of 'a', then of 'b'.
plt.plot(losses)
plt.title('ELBO')
plt.xlabel('step')
plt.ylabel('loss');
plt.figure()
plt.plot(a)
plt.figure()
plt.plot(b)

In [None]:
# Current values of parameters 'a' and 'b' as Python numbers (rebinds `a` and `b`).
a = pyro.param('a').item()
b = pyro.param('b').item()

def f_():
    """Sample x ~ Normal(a, b), then y ~ Normal(x, 0.75), and return y.

    Reads the module-level `a` and `b` at call time.
    """
    x = pyro.sample('x', dist.Normal(a, b))
    y = pyro.sample('y', dist.Normal(x, 0.75))
    return y

# Histogram of 50000 draws of y from f_.
out = [f_() for _ in range(50000)]
hist(out, bins=96)
# Vertical markers from 0 to 0.7: red at 9.5, magenta at the value of `a`.
plt.plot([9.5]*14, np.linspace(0, 0.7, 14), c='r', linewidth=3)
plt.plot([a]*14, np.linspace(0, 0.7, 14), c='m', linewidth=3)
# Ends the cell with a statement so no expression value is displayed.
pass

---

## Simulate `observe`?

### Rejector

In [None]:
class Wrapper(dist.Rejector):
    """Thin subclass of ``dist.Rejector`` that forwards its arguments unchanged."""

    def __init__(self, underlying, log_accept, log_scale):
        super().__init__(underlying, log_accept, log_scale)


# Proposal distribution for the rejection sampler.
underlying = dist.Normal(0, 1)

# Threshold of the predicate x > x0.
x0 = T.tensor(0.0)

# This actually implements the predicate, in log space:
# log(1) = 0 where x > x0 (always accept), log(0) = -inf otherwise (never accept).
log_accept = lambda x: (x > x0).float().log()

# Log of the total acceptance probability:
#   log P(X > x0) = log(1 - P(X <= x0)).
# The parentheses matter: `1 - underlying.cdf(x0).log()` would compute
# 1 - log P(X <= x0), which is not a log-probability.
log_scale = (1 - underlying.cdf(x0)).log()

w = Wrapper(underlying, log_accept, log_scale)
# Draw 10000 samples from the rejection-sampled distribution.
xs = [pyro.sample('x', w) for _ in range(10000)]

plt.figure(figsize=(14, 4))
_, bins, _ = plt.hist(xs, bins=32, density=True, color='c')
plt.xlim([-3, 3])

# Ends the cell with a statement so no expression value is displayed.
pass

### Importance sampling
- we want to find the posterior $\left( x | y \sim \text{Bernoulli}(1) \right)$
- Bernoulli's success rate controls the predicate compliance
- `model()` gives $P(y | x)$
- $P(x|y) = P(y|x) \cdot P(x) \ / \ P(y)$

#### Observations
- `Empirical` (therefore `marginal`) does not have a `cdf`
- support = samples from posterior

In [None]:
# Predicates
def eq(a, b):
    """Return 1.0 if ``np.isclose(a, b, atol=1e-3)`` holds, else 0.0."""
    return float(np.isclose(a, b, atol=1e-3))


def eq_any(x, xs):
    """Return a float tensor: 1.0 if ``x`` is ``eq`` to any element of ``xs``, else 0.0."""
    matched = any(eq(candidate, x) for candidate in xs)
    return T.tensor(matched).float()


def gt(x, y):
    """Element-wise ``x > y`` as a float tensor (1.0 where true, 0.0 where false)."""
    return (x > y).float()


def lt(x, y):
    """Element-wise ``x < y`` as a float tensor (1.0 where true, 0.0 where false)."""
    return (x < y).float()

In [None]:
def model():
    """Model that observes the predicate ``x > 0`` on ``x ~ Normal(0, 1)``.

    Site ``'y'`` is a ``Bernoulli(s)`` with ``s = 1``, observed at the float
    indicator ``gt(x, 0)``:

        x <= 0  ->  y is observed as 0, which has probability 1 - s
        x  > 0  ->  y is observed as 1, which has probability s

    Returns:
        The value of site ``'y'``, i.e. the observed indicator.
    """
    # 1. observe(x > 0)
    x = pyro.sample('x', dist.Normal(0, 1))
    x_is_positive = gt(x, 0)
    return pyro.sample('y', dist.Bernoulli(1.0), obs=x_is_positive)
    
#     2. obtain z = x + y (choosing a good prior)
#     x = pyro.sample('x', dist.Normal(1,1))
#     y = pyro.sample('y', dist.Normal(1,1))
#     z = pyro.sample('z', dist.Normal(1.5,1), obs=(x+y))
#     return z

#     3. y < x
#     x = pyro.sample('x', dist.Uniform(0, 1))
#     y = pyro.sample('y', dist.Uniform(0, 1))
#     z = pyro.sample('z', dist.Bernoulli(1.0), obs=lt(y, x))
#     return z

# Number of importance samples.
N = 5000

# perform posterior inference by importance sampling
posterior = pyro.infer.Importance(model, num_samples=N)

# construct marginal distribution
# (running the sampler happens here, via posterior.run())
marginal  = pyro.infer.EmpiricalMarginal(posterior.run(), sites=['x'])

# Draw 2*N values of 'x' from the empirical marginal.
samples = [marginal.sample() for _ in range(2*N)]

# Density histogram of the drawn values over [-4, 4].
plt.figure(figsize=(10,4))
plt.hist(samples, range=[-4, 4], bins=64, color='c', label='marginal', density=True, stacked=True)
plt.legend()
plt.show()

In [None]:
# NOTE(review): posterior.run() was already called when `marginal` was built;
# this calls it again — presumably re-running the sampler, confirm that is intended.
xs = posterior.run()

# Print the sampled 'x' and the exponentiated log-weight for the first 10 traces.
for i, (tr, log_weight) in enumerate(zip(xs.exec_traces, xs.log_weights)):
    if i == 10: break
    x = tr.nodes['x']['value'].item()
    y = log_weight.exp().item()
    print(f'{x:8.5f}: {y:7.5f}')

$$
\begin{align}
P(x > 0) &= \frac{1}{\texttt{xs[xs > 0].size}} \\
P(x \leq 0) &= 0
\end{align}
$$

In [None]:
# Probability that the empirical marginal assigns to the value x
# (exponentiated log_prob; the result is left as a tensor).
p = lambda x: marginal.log_prob(T.tensor([x])).exp()

# Support points of the empirical marginal and the probability evaluated at each.
# NOTE(review): the shapes of `xs` and `ps` depend on what enumerate_support() yields — confirm.
xs = np.array([x for x in marginal.enumerate_support()])
ps = np.array([p(x) for x in xs])

# Reference level: one over the number of positive support points.
p_ = 1/xs[xs>0].size
print(f'P(x > 0) = {p_:9.7f}')

# Scatter of (support point, probability) with a horizontal line at the reference level.
plt.figure(figsize=(12, 6))
plt.scatter(xs, ps, s=1, c='c')
plt.plot([xs.min(), xs.max()], [p_] * 2, linewidth=1, c='b', alpha=0.5)

# step prob: 
# p(x  > 0) = 1 / xs[xs > 0].size 
# p(x <= 0) = 0

# Ends the cell with a statement so no expression value is displayed.
pass