In [1]:
# Record when this notebook was last executed (IPython `!` runs a shell command).
!date

Sun Feb 20 16:26:10 PST 2022


In [2]:
import altair
import beanmachine.ppl as bm
import numpy as np
import pandas as pd
import torch
import torch.distributions as dist
from torch import tensor

In [4]:
# Rate parameter (lambda) of the Exponential prior placed on the reproduction
# rate; the prior mean is therefore 1 / reproduction_rate_rate.
reproduction_rate_rate: float = 10.0

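The prior on the reproduction rate is an exponential distribution with rate $\lambda$ = `reproduction_rate_rate`, whose density is

$$
f(x; \lambda) =
\begin{cases}
\lambda e^{-\lambda x}, & x \ge 0 \\
0, & x \lt 0
\end{cases}
$$

so its prior mean is $1/\lambda$.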

In [5]:
@bm.random_variable
def reproduction_rate():
    """Prior over the reproduction rate.

    An Exponential distribution whose rate parameter is the module-level
    constant ``reproduction_rate_rate``.
    """
    prior = dist.Exponential(reproduction_rate_rate)
    return prior

In [28]:
# Outside of inference, calling a @bm.random_variable function does not draw a
# sample: it returns an RVIdentifier that names the random variable.
pointer = reproduction_rate()
assert isinstance(pointer, bm.RVIdentifier)

In [6]:
@bm.random_variable
def num_new(num_current: int):
    """Poisson-distributed count with mean ``reproduction_rate() * num_current``.

    Inside this random-variable function, ``reproduction_rate()`` is used as the
    value of that random variable (it is multiplied by ``num_current``). Each
    distinct ``num_current`` argument identifies a separate random variable.
    """
    expected_count = reproduction_rate() * num_current
    return dist.Poisson(expected_count)

In [8]:
# Argument passed to num_new(); the observation below conditions on num_new(num_init).
num_init = 1_087_980

# Condition the model: the random variable num_new(num_init) was observed to equal 238154.
# NOTE(review): the source/provenance of these two counts isn't recorded here — document it.
observations = {num_new(num_init): tensor(238154)}

In [10]:
# Draw posterior samples of reproduction_rate() conditioned on `observations`.
# Each chain runs num_adaptive_samples warm-up iterations (not returned) followed
# by num_samples kept draws, i.e. 10,000 iterations per chain.
# Seed the torch RNG so that Restart & Run All reproduces the same draws.
torch.manual_seed(0)
samples = bm.CompositionalInference().infer(
    queries=[reproduction_rate()],
    observations=observations,
    num_samples=7000,
    num_adaptive_samples=3000,
    # Explicit (matches the previous default); later cells index chain 0.
    num_chains=4,
)

Samples collected: 100%|██████████| 10000/10000 [00:05<00:00, 1911.07it/s]
Samples collected: 100%|██████████| 10000/10000 [00:05<00:00, 1968.43it/s]
Samples collected: 100%|██████████| 10000/10000 [00:04<00:00, 2173.07it/s]
Samples collected: 100%|██████████| 10000/10000 [00:04<00:00, 2065.62it/s]


In [12]:
# Inspect the container type returned by infer().
type(samples)

beanmachine.ppl.inference.monte_carlo_samples.MonteCarloSamples

In [14]:
# The keys are the queried random variables, as RVIdentifiers.
list(samples.keys())

[RVIdentifier(wrapper=<function reproduction_rate at 0x120aeb310>, arguments=())]

In [15]:
# Index the results by the query's RVIdentifier. The leading axis indexes chains,
# so [0] keeps only the first chain's num_samples draws.
# NOTE(review): the other chains are discarded — consider pooling them or using
# them for convergence diagnostics.
reproduction_rate_samples = samples[reproduction_rate()][0]

In [16]:
# Preview the draws (torch elides the middle of long tensors).
reproduction_rate_samples

tensor([0.2185, 0.2187, 0.2193,  ..., 0.2187, 0.2190, 0.2184])

In [26]:
# One value per kept draw of the first chain (should equal num_samples).
reproduction_rate_samples.shape

torch.Size([7000])

In [18]:
# Density-normalised histogram of the draws, with numpy's "auto" bin rule:
# h[i] is the density over the bin [edges[i], edges[i + 1]].
h, edges = np.histogram(
    reproduction_rate_samples.numpy(),
    density=True,
    bins="auto",
)

In [25]:
# Bar chart of the histogram computed above: one bar per bin spanning
# [x1, x2], with height equal to the normalised density.
# A chart title makes the figure readable on its own when the notebook is skimmed.
altair.Chart(
    pd.DataFrame({"x1": edges[:-1], "x2": edges[1:], "density": h}),
    title="Posterior distribution of the reproduction rate",
).mark_bar().encode(
    x=altair.X("x1", title="Reproduction rate"),
    x2="x2",
    y=altair.Y("density", title="Probability density"),
).interactive()


In [37]:
@bm.random_variable
def foo(i):
    """Standard-normal random variable indexed by ``i``.

    ``i`` does not affect the distribution; it only distinguishes separate
    random variables, so foo(0) and foo(1) yield distinct RVIdentifiers.
    """
    return dist.Normal(loc=0.0, scale=1.0)

In [38]:
# The call's arguments are stored in the returned RVIdentifier.
foo(0)

RVIdentifier(wrapper=<function foo at 0x121e3b5e0>, arguments=(0,))

In [39]:
# Same function, different argument: a different RVIdentifier, i.e. a separate random variable.
foo(1)

RVIdentifier(wrapper=<function foo at 0x121e3b5e0>, arguments=(1,))