In [1]:
# Record when this notebook was executed (runs the shell `date` command via IPython's `!`).
!date

Sun Mar  6 21:39:00 PST 2022


In [2]:
import altair
import beanmachine.ppl as bm
import numpy as np
import pandas as pd
import torch
import torch.distributions as dist
from torch import tensor

In [3]:
# Rate parameter (lambda) of the Exponential prior placed on `reproduction_rate`
# in the model below; the prior mean is therefore 1 / reproduction_rate_rate.
reproduction_rate_rate: float = 10.0

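The prior on the reproduction rate is an exponential distribution whose rate $\lambda$ is `reproduction_rate_rate`. Its probability density is

$$
f(x; \lambda) =
\begin{cases}
\lambda e^{-\lambda x}, & x \ge 0 \\
0, & x < 0
\end{cases}
$$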

In [4]:
@bm.random_variable
def reproduction_rate():
    """Prior over the reproduction rate.

    Returns an Exponential distribution parameterised by the module-level
    ``reproduction_rate_rate``.
    """
    prior = dist.Exponential(rate=reproduction_rate_rate)
    return prior

In [5]:
# Calling a @bm.random_variable-decorated function here returns an
# RVIdentifier naming the random variable, not a sampled value.
pointer = reproduction_rate()
assert isinstance(pointer, bm.RVIdentifier)

In [6]:
@bm.random_variable
def num_new(num_current: int):
    """Poisson-distributed ``num_new`` count conditioned on ``num_current``.

    The Poisson mean is ``reproduction_rate() * num_current``. Each distinct
    ``num_current`` argument identifies a distinct random variable.
    """
    poisson_mean = reproduction_rate() * num_current
    return dist.Poisson(poisson_mean)

In [7]:
# Value passed as `num_current` to `num_new` for the observed random variable.
num_init = 1_087_980
# Observed value of num_new(num_init); named instead of left as a magic number.
# Because the count is large, the posterior of reproduction_rate should land
# close to num_observed_new / num_init (about 0.219).
num_observed_new = 238_154

# Condition the model: num_new(num_init) was observed to equal num_observed_new.
observations = {num_new(num_init): tensor(num_observed_new)}

In [8]:
# Seed torch's global RNG so that re-running the notebook from a fresh kernel
# reproduces the same posterior draws.
SEED = 0
torch.manual_seed(SEED)

# Draw posterior samples of `reproduction_rate` given the observation above.
# num_chains is spelled out because later cells index the chain axis: each of
# the four chains runs num_adaptive_samples adaptation draws (excluded from the
# result) followed by num_samples kept draws.
samples = bm.CompositionalInference().infer(
    queries=[reproduction_rate()],
    observations=observations,
    num_samples=7000,
    num_adaptive_samples=3000,
    num_chains=4,
)

Samples collected: 100%|██████████| 10000/10000 [00:08<00:00, 1144.52it/s]
Samples collected: 100%|██████████| 10000/10000 [00:06<00:00, 1448.04it/s]
Samples collected: 100%|██████████| 10000/10000 [00:05<00:00, 1740.60it/s]
Samples collected: 100%|██████████| 10000/10000 [00:06<00:00, 1549.43it/s]


In [9]:
# Inspect the container type returned by infer().
type(samples)

beanmachine.ppl.inference.monte_carlo_samples.MonteCarloSamples

In [10]:
# The keys are the RVIdentifiers that were passed as queries.
list(samples.keys())

[RVIdentifier(wrapper=<function reproduction_rate at 0x1269afca0>, arguments=())]

In [11]:
# Look up the draws by RVIdentifier. The leading axis indexes chains, so
# [0, :] keeps every kept draw from the first chain only.
reproduction_rate_samples = samples[reproduction_rate()][0, :]

In [12]:
# Preview the first-chain draws.
reproduction_rate_samples

tensor([0.2187, 0.2193, 0.2193,  ..., 0.2185, 0.2195, 0.2190])

In [13]:
# One entry per kept draw (num_samples) from the first chain.
reproduction_rate_samples.size()

torch.Size([7000])

In [14]:
# Density-normalised histogram of the draws; numpy chooses the bin edges ("auto").
h, edges = np.histogram(
    np.asarray(reproduction_rate_samples),
    bins="auto",
    density=True,
)

In [15]:
# Draw each histogram bin as a bar spanning [x1, x2) whose height is the
# estimated probability density; the title lets the figure stand alone.
(
    altair.Chart(
        pd.DataFrame({"x1": edges[:-1], "x2": edges[1:], "density": h}),
        title="Posterior of the reproduction rate (first chain)",
    )
    .mark_bar()
    .encode(
        x=altair.X("x1", title="Reproduction rate"),
        x2="x2",
        y=altair.Y("density", title="Probability density"),
    )
    .interactive()
)


In [16]:
@bm.random_variable
def foo(i):
    """Standard-normal random variable indexed by ``i``.

    ``i`` does not influence the distribution; it only distinguishes
    otherwise-identical random variables (foo(0) and foo(1) yield different
    RVIdentifiers).
    """
    return dist.Normal(loc=0.0, scale=1.0)

In [17]:
# The call's argument is recorded in the RVIdentifier (arguments=(0,)).
foo(0)

RVIdentifier(wrapper=<function foo at 0x1269caf70>, arguments=(0,))

In [18]:
# Same function, different argument: a different RVIdentifier (arguments=(1,)).
foo(1)

RVIdentifier(wrapper=<function foo at 0x1269caf70>, arguments=(1,))