Add cumsum logprob
Ricardo committed Oct 28, 2021
1 parent 3de67d1 commit eb01b9c
Showing 3 changed files with 167 additions and 0 deletions.
1 change: 1 addition & 0 deletions aeppl/__init__.py
@@ -11,6 +11,7 @@

# isort: off
# Add optimizations to the DBs
import aeppl.cumsum
import aeppl.mixture
import aeppl.scan

84 changes: 84 additions & 0 deletions aeppl/cumsum.py
@@ -0,0 +1,84 @@
from typing import List, Optional

import aesara.tensor as at
from aesara.graph.opt import local_optimizer
from aesara.tensor.extra_ops import CumOp
from aesara.tensor.var import TensorVariable

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
from aeppl.logprob import _logprob, logprob
from aeppl.opt import PreserveRVMappings, rv_sinking_db


class MeasurableCumsum(CumOp):
"""A placeholder used to specify a log-likelihood for a cumsum sub-graph."""


MeasurableVariable.register(MeasurableCumsum)


@_logprob.register(MeasurableCumsum)
def logprob_cumsum(op, values, base_rv, **kwargs):
"""Compute the log-likelihood graph for a `Cumsum`."""
(value,) = values

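    # `cumsum` is an invertible transform whose Jacobian is lower-triangular with
    # a unit diagonal, so its log-determinant is zero and no Jacobian correction
    # is needed: the log-density of `value` is the base density evaluated at the
    # inverse transform, i.e. the first element followed by successive differences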
value_diff = at.diff(value, axis=op.axis)
value_diff = at.concatenate(
(
            # Take the first element along `axis` and add back a broadcastable
            # dimension so that it can be concatenated with the rest of `value_diff`
at.shape_padaxis(
at.take(value, 0, axis=op.axis),
axis=op.axis,
),
value_diff,
),
axis=op.axis,
)

cumsum_logp = logprob(base_rv, value_diff)

return cumsum_logp


@local_optimizer([CumOp])
def find_measurable_cumsums(fgraph, node) -> Optional[List[TensorVariable]]:
    r"""Finds `CumOp`\s for which a `logprob` can be computed."""

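    # `CumOp` implements both cumsum (`mode="add"`) and cumprod (`mode="mul"`);
    # only the cumsum case is handled here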
if not (isinstance(node.op, CumOp) and node.op.mode == "add"):
return None

if isinstance(node.op, MeasurableCumsum):
return None

    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

rv = node.outputs[0]

if rv not in rv_map_feature.rv_values:
return None # pragma: no cover

base_rv = node.inputs[0]
if not (
base_rv.owner
and isinstance(base_rv.owner.op, MeasurableVariable)
and base_rv not in rv_map_feature.rv_values
):
return None # pragma: no cover

    # Check that cumsum does not mix dimensions: with `axis=None` a
    # multidimensional input is flattened first, so the elementwise
    # diff-based inverse in `logprob_cumsum` would no longer apply
if base_rv.ndim > 1 and node.op.axis is None:
return None

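    # `axis` can only be `None` for 1d inputs at this point, where axis 0 is equivalent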
new_op = MeasurableCumsum(axis=node.op.axis or 0, mode="add")
    # Make base_rv unmeasurable so that it is not also assigned a logprob of its own
unmeasurable_base_rv = assign_custom_measurable_outputs(base_rv.owner)
new_rv = new_op.make_node(unmeasurable_base_rv).default_output()
new_rv.name = rv.name

return [new_rv]


rv_sinking_db.register("find_measurable_cumsums", find_measurable_cumsums, -5, "basic")
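
Once `aeppl.cumsum` is imported, this rewrite lets `joint_logprob` derive the density of a cumulative sum of a measurable random variable. A minimal usage sketch (variable names are illustrative; it mirrors the tests below):

import aesara.tensor as at

from aeppl import joint_logprob

# A Gaussian random walk written as the cumsum of iid normal increments
y_rv = at.cumsum(at.random.normal(0, 1, size=10))
y_vv = y_rv.clone()  # value variable standing in for an observation of y_rv

# `find_measurable_cumsums` swaps the `CumOp` for a `MeasurableCumsum`, so the
# resulting graph is the normal log-density evaluated at the first element of
# `y_vv` followed by its successive differences
logp = joint_logprob({y_rv: y_vv})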
82 changes: 82 additions & 0 deletions tests/test_cumsum.py
@@ -0,0 +1,82 @@
import aesara
import aesara.tensor as at
import numpy as np
import pytest
import scipy.stats as st

from aeppl import joint_logprob
from tests.utils import assert_no_rvs


@pytest.mark.parametrize(
"size, axis",
[
(10, None),
(10, 0),
((2, 10), 0),
((2, 10), 1),
((3, 2, 10), 0),
((3, 2, 10), 1),
((3, 2, 10), 2),
],
)
def test_normal_cumsum(size, axis):
rv = at.random.normal(0, 1, size=size).cumsum(axis)
vv = rv.clone()
logp = joint_logprob({rv: vv})
assert_no_rvs(logp)

assert np.isclose(
st.norm(0, 1).logpdf(np.ones(size)).sum(),
logp.eval({vv: np.ones(size).cumsum(axis)}),
)


@pytest.mark.parametrize(
"size, axis",
[
(10, None),
(10, 0),
((2, 10), 0),
((2, 10), 1),
((3, 2, 10), 0),
((3, 2, 10), 1),
((3, 2, 10), 2),
],
)
def test_bernoulli_cumsum(size, axis):
rv = at.random.bernoulli(0.9, size=size).cumsum(axis)
vv = rv.clone()
logp = joint_logprob({rv: vv})
assert_no_rvs(logp)

assert np.isclose(
st.bernoulli(0.9).logpmf(np.ones(size)).sum(),
logp.eval({vv: np.ones(size, int).cumsum(axis)}),
)


def test_destructive_cumsum_fails():
"""Test that a cumsum that mixes dimensions fails"""
x_rv = at.random.normal(size=(2, 2, 2)).cumsum()
x_vv = x_rv.clone()
with pytest.raises(KeyError):
joint_logprob({x_rv: x_vv})


def test_deterministic_cumsum():
"""Test that deterministic cumsum is not affected"""
x_rv = at.random.normal(1, 1, size=5)
cumsum_x_rv = at.cumsum(x_rv)
y_rv = at.random.normal(cumsum_x_rv, 1)

x_vv = x_rv.clone()
y_vv = y_rv.clone()
logp = joint_logprob({x_rv: x_vv, y_rv: y_vv})
assert_no_rvs(logp)

logp_fn = aesara.function([x_vv, y_vv], logp)
assert np.isclose(
logp_fn(np.ones(5), np.arange(5) + 1),
st.norm(1, 1).logpdf(1) * 10,
)
