Skip to content

Commit

Permalink
Add goal for beta-binomial observation model
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Apr 29, 2022
1 parent 0a8cc00 commit f1df8cb
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 2 deletions.
57 changes: 57 additions & 0 deletions aemcmc/conjugates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import aesara.tensor as at
from etuples import etuple, etuplize
from kanren import eq, lall
from unification import var


def beta_binomial_conjugateo(model_expr, observation_expr, posterior_expr):
r"""Produce a goal that represents the application of Bayes theorem
for a beta prior with a binomial observation model.
.. math::
\begin{align*}
p &\sim \operatorname{Beta}\left(\alpha, \beta\right)\\
y &\sim \operatorname{Binomial}\left(n, p\right)
\end{align*}
If we observe :math:`y=Y`, then :math:`p` follows a beta distribution:
.. math::
p \sim \operatorname{Beta}\left(\alpha + Y, \beta + n - Y\right)
"""

# Beta-binomial observation model
alpha_lv, beta_lv = var(), var()
p_rng_lv = var()
p_size_lv = var()
p_type_idx_lv = var()
p_et = etuple(
etuplize(at.random.beta), p_rng_lv, p_size_lv, p_type_idx_lv, alpha_lv, beta_lv
)
n_lv = var()
Y_et = etuple(etuplize(at.random.binomial), var(), var(), var(), n_lv, p_et)

y_lv = var() # observation

# Posterior distribution for p
new_alpha_et = etuple(etuplize(at.add), alpha_lv, y_lv)
new_beta_et = etuple(
etuplize(at.sub), etuple(etuplize(at.add), beta_lv, n_lv), y_lv
)
p_posterior_et = etuple(
etuplize(at.random.beta),
new_alpha_et,
new_beta_et,
rng=p_rng_lv,
size=p_size_lv,
dtype=p_type_idx_lv,
)

return lall(
eq(model_expr, Y_et),
eq(observation_expr, y_lv),
eq(posterior_expr, p_posterior_et),
)
7 changes: 5 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
maintainer_email="aesara-devs@gmail.com",
packages=["aemcmc"],
install_requires=[
"numpy>=1.18.1",
"scipy>=1.4.0",
"aesara",
"aeppl",
"etuples",
"logical-unification",
"miniKanren",
"numpy>=1.18.1",
"scipy>=1.4.0",
],
tests_require=["pytest"],
long_description=open("README.md").read() if exists("README.md") else "",
Expand Down
59 changes: 59 additions & 0 deletions tests/test_conjugates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import aesara
import aesara.tensor as at
import pytest
from aesara.graph.unify import eval_if_etuple
from aesara.tensor.random import RandomStream
from kanren import run
from unification import var

from aemcmc.conjugates import beta_binomial_conjugateo


def test_beta_binomial_conjugate_contract():
"""Produce the closed-form posterior for the binomial observation model with
a beta prior.
"""
srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
p_rv = srng.beta(alpha_tt, beta_tt, name="p")

n_tt = at.iscalar("n")
Y_rv = srng.binomial(n_tt, p_rv)
y_vv = Y_rv.clone()
y_vv.tag.name = "y"

q_lv = var()
(posterior_expr,) = run(1, q_lv, beta_binomial_conjugateo(Y_rv, y_vv, q_lv))
posterior = eval_if_etuple(posterior_expr)
aesara.dprint(posterior)

assert isinstance(posterior.owner.op, type(at.random.beta))

# Build the sampling function and check the results on limiting cases.
sample_fn = aesara.function((alpha_tt, beta_tt, y_vv, n_tt), posterior)
assert sample_fn(1.0, 1.0, 1000, 1000) == pytest.approx(
1.0, abs=0.001
) # only successes
assert sample_fn(1.0, 1.0, 0, 1000) == pytest.approx(0.0, abs=0.001) # zero success


@pytest.mark.xfail
def test_beta_binomial_conjugate_expand():
"""Expand a contracted beta-binomial observation model."""

srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
y_vv = at.iscalar("y")
n_tt = at.iscalar("n")
Y_rv = srng.beta(alpha_tt + y_vv, beta_tt + n_tt - y_vv)

e_lv = var()
(expanded_expr,) = run(1, e_lv, beta_binomial_conjugateo(e_lv, y_vv, Y_rv))
expanded = eval_if_etuple(expanded_expr)

assert isinstance(expanded.owner.op, type(at.random.beta))

0 comments on commit f1df8cb

Please sign in to comment.