Implement a general mixture log-probability rewrite
brandonwillard committed Jul 1, 2021
1 parent 517fa3d commit 13ccc2c
Showing 3 changed files with 300 additions and 1 deletion.
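
For orientation, the kind of graph this rewrite targets looks as follows; this is a minimal sketch mirroring the tests added below, not part of the diff itself:

import aesara.tensor as at

from aeppl.joint_logprob import joint_logprob

# Two mixture components and a discrete index that selects between them
X_rv = at.random.normal(0, 1, name="X")
Y_rv = at.random.gamma(0.5, 0.5, name="Y")
I_rv = at.random.bernoulli(0.8, name="I")

# The mixture is expressed by indexing a stack of components
M_rv = at.stack([X_rv, Y_rv])[I_rv]
M_rv.name = "M"

i_vv = I_rv.clone()
m_vv = M_rv.clone()

# With this commit, `joint_logprob` can produce a log-probability graph for
# the indexed stack by rewriting it into a `MixtureRV` placeholder
M_logp = joint_logprob(M_rv, {M_rv: m_vv, I_rv: i_vv})
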
6 changes: 6 additions & 0 deletions aeppl/joint_logprob.py
@@ -158,6 +158,12 @@ def joint_logprob(
node.op, q_rv_value_var, *value_var_inputs, name=q_rv_var.name, **kwargs
)

if q_logprob_var is None:
# When a `_logprob` returns `None` it signifies that the node
# can be skipped, because, for example, its log-likelihood is
# determined independently by its inputs.
continue

if q_rv_var.name:
q_logprob_var.name = f"{q_rv_var.name}_logprob"

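The `None`-return convention above means a `_logprob` implementation can opt a node out of the joint log-probability entirely. A minimal sketch of the pattern (the `SkippableOp` class is hypothetical, for illustration only):

from aesara.graph.op import Op

from aeppl.logprob import _logprob


class SkippableOp(Op):  # hypothetical `Op`, for illustration only
    ...


@_logprob.register(SkippableOp)
def logprob_skippable(op, value, *inputs, **kwargs):
    # Returning `None` tells `joint_logprob` to skip this node's term,
    # e.g. because its contribution is accounted for elsewhere
    return None
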
212 changes: 211 additions & 1 deletion aeppl/opt.py
@@ -1,13 +1,20 @@
from typing import Dict
from typing import Dict, List, Optional, Tuple

import aesara
import aesara.tensor as at
import numpy as np
from aesara.compile.mode import optdb
from aesara.graph.basic import Apply, Constant
from aesara.graph.features import Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value
from aesara.graph.opt import local_optimizer
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.tensor.basic import Join, MakeVector, ScalarFromTensor
from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.random.basic import BernoulliRV, CategoricalRV
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import (
local_dimshuffle_rv_lift,
@@ -17,13 +24,18 @@
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
)
from aesara.tensor.var import TensorVariable

from aeppl.logprob import _logprob, logprob
from aeppl.utils import indices_from_subtensor

inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
subtensor_ops = (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)


class PreserveRVMappings(Feature):
@@ -150,11 +162,209 @@ def naive_bcast_rv_lift(fgraph, node):
return [bcasted_node.outputs[1]]


class MixtureRV(RandomVariable):
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""

@classmethod
    def create_node(cls, node, join_axis, indices, mixture_rvs):
out_var = node.default_output()

inputs = list(indices) + list(mixture_rvs)
mixture_op = MixtureRV(
f"{out_var.name if out_var.name else ''}-mixture",
out_var.ndim,
[inp.ndim for inp in inputs],
out_var.dtype,
inplace=False,
)

        mixture_op.num_indices = len(indices)
        mixture_op.idx_list = getattr(node.op, "idx_list", None)
        # Store the join axis so that `logprob_MixtureRV` can locate the
        # mixing dimension among the indices
        mixture_op.join_axis = join_axis

new_node = mixture_op.make_node(None, None, None, *inputs)

return new_node

def _infer_shape(self, size, dist_params, param_shapes=None):
# Ignore the index inputs
return super()._infer_shape(
size, dist_params[self.num_indices :], param_shapes=param_shapes
)

def rng_fn(self, rng, *args, **kwargs):
raise NotImplementedError()


def get_stack_mixing_mixture_vars(
node: Apply,
) -> Optional[Tuple[int, TensorVariable, List[TensorVariable]]]:
    r"""Extract the join axis, mixing variable, and mixture components from a `*Subtensor*` applied to stacked `RandomVariable`\s."""
if not isinstance(node.op, subtensor_ops):
return

joined_rvs = node.inputs[0]

# First, make sure that it's some sort of combination
if not (joined_rvs.owner and isinstance(joined_rvs.owner.op, (MakeVector, Join))):
# Node is not a compatible join `Op`
return None

if isinstance(joined_rvs.owner.op, MakeVector):
mixture_rvs = joined_rvs.owner.inputs
join_axis = mixing_axis_idx = 0

elif isinstance(joined_rvs.owner.op, Join):
mixture_rvs = joined_rvs.owner.inputs[1:]

join_axis = joined_rvs.owner.inputs[0]

# We need a constant axis value, so we try constant folding
axis_fg = FunctionGraph(
outputs=[join_axis], features=[ShapeFeature()], clone=True
)

        folded_axis = optimize_graph(
            axis_fg, custom_opt=topo_constant_folding
        ).outputs[0]

if not isinstance(folded_axis, Constant):
# TODO: We could use `Scan` for all of this and cover the
# non-constant case.
return None

join_axis = int(folded_axis.data)
mixing_axis_idx = node.default_output().ndim - join_axis - 1

if not all(
rv.owner and isinstance(rv.owner.op, RandomVariable) for rv in mixture_rvs
):
# Currently, all mixture components must be `RandomVariable` outputs
# TODO: Allow constants and make them Dirac-deltas
return None

indices = indices_from_subtensor(
getattr(node.op, "idx_list", None), node.inputs[1:]
)

if len(indices) - 1 < mixing_axis_idx:
# No indexing occurs on the joined axis
return None

idx = indices[mixing_axis_idx]

if not idx.owner:
# TODO: This could be a `Constant`; we might want to support this.
return None

if isinstance(idx.owner.op, ScalarFromTensor):
# This covers the case of a single scalar index
idx = idx.owner.inputs[0]

if idx.owner and isinstance(idx.owner.op, (BernoulliRV, CategoricalRV)):
mixing_rv = idx
else:
return None

return join_axis, mixing_rv, mixture_rvs


@local_optimizer(subtensor_ops)
def mixture_replace(fgraph, node):
r"""Identify mixture sub-graphs and replace them with a place-holder `Op`.
The basic idea is to find ``stack(mixture_comps)[I_rv]``, where
``mixture_comps`` is a ``list`` of `RandomVariable`\s and ``I_rv`` is a
`RandomVariable` with a discrete and finite support.
From these terms, new terms ``Z_rv[i] = mixture_comps[i][i == I_rv]`` are
created for each ``i`` in ``enumerate(mixture_comps)``.
"""

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return

out_var = node.default_output()

if out_var not in rv_map_feature.rv_values:
return

# These should be a combination of `RandomVariable`s
mixture_res = get_stack_mixing_mixture_vars(node)

if mixture_res is None:
return

join_axis, mixing_rv, mixture_rvs = mixture_res

assert mixing_rv in rv_map_feature.rv_values

mixture_value_var = rv_map_feature.rv_values.pop(out_var, None)

    # Mark the mixture components so that their log-probabilities aren't
    # computed separately; they're accounted for by the `MixtureRV`'s
    # log-probability implementation below.
for i, component_rv in enumerate(mixture_rvs):
if component_rv in rv_map_feature.rv_values:
raise ValueError("A value variable was specified for a mixture component")
component_rv.tag.ignore_logprob = True

# Replace this sub-graph with a `MixtureRV`
    new_node = MixtureRV.create_node(node, join_axis, node.inputs[1:], mixture_rvs)

new_mixture_rv = new_node.default_output()
new_mixture_rv.name = "mixture"
rv_map_feature.rv_values[new_mixture_rv] = mixture_value_var

# FIXME: This is pretty hackish
fgraph.import_node(new_node, import_missing=True, reason="mixture_rv")

return [new_mixture_rv]


@_logprob.register(MixtureRV)
def logprob_MixtureRV(op, value, *inputs, name=None, **kwargs):
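    # The first three inputs are the `rng`, `size`, and `dtype` terms that
    # `RandomVariable.make_node` adds; drop them and keep the indices and
    # mixture components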
_, _, _, *inputs = inputs

indices = inputs[: op.num_indices]
indices = indices_from_subtensor(op.idx_list, indices)
comp_rvs = inputs[op.num_indices :]

if len(indices) > 1:
raise NotImplementedError()

if value.ndim > 0:
logp_val = at.alloc(-np.inf, *tuple(at.shape(value)))

mixing_axis_idx = value.ndim - op.join_axis - 1

for i, comp_rv in enumerate(comp_rvs):

selected_mask = at.eq(indices[mixing_axis_idx], i)

new_indices = list(indices)
new_indices[mixing_axis_idx] = at.nonzero(selected_mask)

obs_i = value[tuple(new_indices)]

bcasted_comp_rv = at.broadcast_to(comp_rv, value.shape)
indexed_comp_rv = bcasted_comp_rv[new_indices]

logp_val = at.set_subtensor(
logp_val[new_indices], logprob(indexed_comp_rv, obs_i)
)
else:
logp_val = 0.0
for i, comp_rv in enumerate(comp_rvs):
selected_mask = at.eq(indices[0], i)
logp_val += logprob(comp_rv, value) * selected_mask

return logp_val


logprob_rewrites_db = SequenceDB()
logprob_rewrites_db.register("canonicalize", optdb["canonicalize"], -10, "basic")
rv_sinking_db = EquilibriumDB()
rv_sinking_db.register("dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic")
rv_sinking_db.register("subtensor_lift", local_subtensor_rv_lift, -5, "basic")
rv_sinking_db.register("broadcast_to_lift", naive_bcast_rv_lift, -5, "basic")
rv_sinking_db.register("incsubtensor_lift", incsubtensor_rv_replace, -5, "basic")
rv_sinking_db.register("mixture_replace", mixture_replace, -5, "basic")
logprob_rewrites_db.register("sinking", rv_sinking_db, -10, "basic")
83 changes: 83 additions & 0 deletions tests/test_joint_logprob.py
@@ -198,3 +198,86 @@ def test_ignore_logprob():
logp_exp = joint_logprob(y_rv_2, {y_rv_2: y})

assert equal_computations([logp], [logp_exp])


@aesara.config.change_flags(compute_test_value="raise")
def test_hetero_mixture_scalar():

X_rv = at.random.normal(0, 1, name="X")
Y_rv = at.random.gamma(0.5, 0.5, name="Y")

p_at = at.scalar("p")
p_at.tag.test_value = 0.8

I_rv = at.random.bernoulli(p_at, name="I")
i_vv = I_rv.clone()
i_vv.name = "i"

M_rv = at.stack([X_rv, Y_rv])[I_rv]
M_rv.name = "M"
m_vv = M_rv.clone()
m_vv.name = "m"

M_logp = joint_logprob(M_rv, {M_rv: m_vv, I_rv: i_vv})

M_logp_fn = aesara.function([p_at, m_vv, i_vv], M_logp)

# The compiled graph should not contain any `RandomVariables`
assert_no_rvs(M_logp_fn.maker.fgraph.outputs[0])

decimals = 6 if aesara.config.floatX == "float64" else 4

test_val_rng = np.random.RandomState(3238)

p_val = 0.8
size = ()
for i in range(10):
bern_sp = sp.bernoulli(p_val)
i_val = bern_sp.rvs(size=size, random_state=test_val_rng)

norm_sp = sp.norm(loc=0, scale=1)
x_val = norm_sp.rvs(size=size, random_state=test_val_rng)

        # NOTE: `at.random.gamma` is assumed to take a rate parameter, so
        # SciPy's scale is its reciprocal
        gamma_sp = sp.gamma(0.5, scale=1.0 / 0.5)
y_val = gamma_sp.rvs(size=size, random_state=test_val_rng)

if i_val == 0:
exp_obs_logps = norm_sp.logpdf(x_val)
m_val = x_val
else:
exp_obs_logps = gamma_sp.logpdf(y_val)
m_val = y_val

exp_obs_logps += bern_sp.logpmf(i_val)

logp_vals = M_logp_fn(p_val, m_val, i_val)

np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals)


@aesara.config.change_flags(compute_test_value="raise")
def test_hetero_mixture_nonscalar():

X_rv = at.random.normal(0, 1, size=2, name="X")
Y_rv = at.random.gamma(0.5, 0.5, size=2, name="Y")

p_at = at.vector("p")
p_at.tag.test_value = np.r_[0.1, 0.9]

I_rv = at.random.bernoulli(p_at, size=2, name="I")
i_vv = I_rv.clone()
i_vv.name = "i"

M_rv = at.stack([X_rv, Y_rv])[I_rv]
M_rv.name = "M"

m_vv = M_rv.clone()
m_vv.name = "m"

M_logp = joint_logprob(M_rv, {M_rv: m_vv, I_rv: i_vv})

M_logp_fn = aesara.function([p_at, m_vv, i_vv], M_logp)

assert_no_rvs(M_logp_fn.maker.fgraph.outputs[0])

raise AssertionError("Not finished")
