In [1]:
# Timestamp this run so the saved outputs record when they were produced.
!date

Sun Mar  6 21:34:28 PST 2022


In [2]:
import altair
import beanmachine.ppl as bm
import numpy as np
import pandas as pd
import torch
import torch.distributions as dist
from torch import tensor

In [3]:
# Select Altair's "mimetype" renderer for displaying charts in this notebook.
# NOTE(review): the saved output of the chart cell further down shows the
# "renderer has not been properly enabled" fallback text, so the frontend used
# for that run did not display the chart — confirm the intended frontend.
altair.renderers.enable("mimetype")

RendererRegistry.enable('mimetype')

In [4]:
# Rate λ of the exponential prior on the reproduction rate (prior mean 1/λ = 0.1).
reproduction_rate_rate = 10.0

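The prior on the reproduction rate is an exponential distribution with rate $\lambda$ (the `reproduction_rate_rate` constant above), whose density is

$$
f(x; \lambda) =
\begin{cases}
\lambda e^{-\lambda x}, & x \ge 0 \\
0, & x < 0
\end{cases}
$$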

In [5]:
@bm.random_variable
def reproduction_rate():
    """Prior over the reproduction rate.

    Exponential with rate ``reproduction_rate_rate``; its support is x >= 0,
    so sampled reproduction rates are never negative.
    """
    prior = dist.Exponential(rate=reproduction_rate_rate)
    return prior

In [6]:
# Calling a @bm.random_variable function outside of inference does not draw a
# value; it returns an RVIdentifier that names the random variable.
pointer = reproduction_rate()
assert isinstance(pointer, bm.RVIdentifier)

In [7]:
@bm.random_variable
def num_new(num_current: int):
    """Poisson-distributed count conditioned on a current count ``num_current``.

    The Poisson rate scales linearly with ``num_current``, using the
    ``reproduction_rate()`` random variable as the multiplier. Each distinct
    ``num_current`` identifies a separate random variable.
    """
    poisson_rate = reproduction_rate() * num_current
    return dist.Poisson(rate=poisson_rate)

In [8]:
# Current count passed to num_new(); also selects which num_new variable is observed.
num_init = 1_087_980
# Observed value of num_new(num_init) — named instead of inlined as a magic number.
num_observed_new = 238_154

# Condition the model on the single observed count.
observations = {num_new(num_init): tensor(num_observed_new)}

In [9]:
# Seed torch's global RNG right before sampling so re-running this cell (or
# Restart & Run All) reproduces the same MCMC draws.
torch.manual_seed(1)

# 4 chains (the previous implicit default), each with 3000 adaptation steps
# followed by 7000 kept draws; later cells index into the chain dimension.
samples = bm.CompositionalInference().infer(
    queries=[reproduction_rate()],
    observations=observations,
    num_samples=7000,
    num_adaptive_samples=3000,
    num_chains=4,
)

Samples collected: 100%|██████████| 10000/10000 [00:06<00:00, 1604.30it/s]
Samples collected: 100%|██████████| 10000/10000 [00:08<00:00, 1204.21it/s]
Samples collected: 100%|██████████| 10000/10000 [00:09<00:00, 1016.99it/s]
Samples collected: 100%|██████████| 10000/10000 [00:07<00:00, 1418.04it/s]


In [10]:
# Inspect the container type returned by infer().
type(samples)

beanmachine.ppl.inference.monte_carlo_samples.MonteCarloSamples

In [11]:
# The samples container is keyed by the queried RVIdentifiers.
list(samples.keys())

[RVIdentifier(wrapper=<function reproduction_rate at 0x107a88dc0>, arguments=())]

In [12]:
# The leading index of samples[...] selects the chain (the inference cell ran
# 4 chains of 7000 kept draws each); only chain 0 is analysed below.
# NOTE(review): samples[reproduction_rate()].flatten() would pool all chains.
reproduction_rate_samples = samples[reproduction_rate()][0, :]

In [13]:
# Preview the posterior draws kept in the previous cell.
reproduction_rate_samples

tensor([0.2185, 0.2196, 0.2186,  ..., 0.2185, 0.2188, 0.2193])

In [14]:
# One entry per kept draw (num_samples=7000 in the inference cell).
reproduction_rate_samples.shape

torch.Size([7000])

In [15]:
# Density-normalised histogram of the posterior draws (numpy chooses the bin
# count via bins="auto"); `edges` holds one more value than `h`.
h, edges = np.histogram(
    reproduction_rate_samples,
    density=True,
    bins="auto",
)

In [16]:
# Draw the histogram as bars: each bar spans its bin from x1 to x2 and its
# height is the normalised density, so the bar areas sum to 1.
(
    altair.Chart(
        pd.DataFrame(
            {
                "x1": edges[:-1],
                "x2": edges[1:],
                "density": h,
            }
        )
    )
    .mark_bar()
    .encode(
        x=altair.X("x1", title="Reproduction rate"),
        x2="x2",
        y=altair.Y("density", title="Probability density"),
    )
    .interactive()
)


<VegaLite 4 object>

If you see this message, it means the renderer has not been properly enabled
for the frontend that you are using. For more information, see
https://altair-viz.github.io/user_guide/troubleshooting.html


In [17]:
@bm.random_variable
def foo(i):
    """Standard normal random variable indexed by ``i``.

    The body never reads ``i``: the argument only distinguishes separate
    random variables, so foo(0) and foo(1) yield different RVIdentifiers.
    """
    return dist.Normal(loc=0.0, scale=1.0)

In [18]:
# The argument becomes part of the identifier: arguments=(0,).
foo(0)

RVIdentifier(wrapper=<function foo at 0x11ee5b0d0>, arguments=(0,))

In [19]:
# Same wrapper function, different argument -> a distinct random variable.
foo(1)

RVIdentifier(wrapper=<function foo at 0x11ee5b0d0>, arguments=(1,))