Implement a general mixture log-probability rewrite
brandonwillard committed Jul 11, 2021
1 parent edff82c commit 33bc7d5
Showing 5 changed files with 453 additions and 6 deletions.
12 changes: 9 additions & 3 deletions aeppl/__init__.py
@@ -4,7 +4,13 @@
 del get_versions
 
 
-from .logprob import logprob # isort: split
+from aeppl.logprob import logprob # isort: split
 
-from .joint_logprob import joint_logprob
-from .printing import latex_pprint, pprint
+from aeppl.joint_logprob import joint_logprob
+from aeppl.printing import latex_pprint, pprint
+
+# isort: off
+# Add optimizations to the DBs
+import aeppl.mixture
+
+# isort: on
237 changes: 237 additions & 0 deletions aeppl/mixture.py
@@ -0,0 +1,237 @@
from typing import List, Optional

import aesara.tensor as at
import numpy as np
from aesara.compile.builders import OpFromGraph
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer, pre_greedy_local_optimizer
from aesara.ifelse import ifelse
from aesara.tensor.basic import Join, MakeVector
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import local_dimshuffle_rv_lift, local_subtensor_rv_lift
from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorVariable

from aeppl.logprob import _logprob, logprob
from aeppl.opt import naive_bcast_rv_lift, rv_sinking_db, subtensor_ops
from aeppl.utils import get_constant_value, indices_from_subtensor


def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
    """Pull a ``RandomVariable`` ``Op`` down through a graph, when possible."""
    fgraph = FunctionGraph(outputs=dont_touch_vars or [], clone=False)

    return pre_greedy_local_optimizer(
        fgraph,
        [
            local_dimshuffle_rv_lift,
            local_subtensor_rv_lift,
            naive_bcast_rv_lift,
        ],
        x,
    )
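For intuition, here is a minimal sketch (not part of the commit) of the kind of transformation these lift rewrites are expected to perform when driven by `rv_pull_down`; the variable names are illustrative:

```python
import aesara.tensor as at

from aeppl.mixture import rv_pull_down

mu = at.vector("mu")
X_rv = at.random.normal(mu, name="X")  # a vector of independent normals
x_indexed = X_rv[1]                    # the Subtensor sits *above* the RandomVariable

# `rv_pull_down(x_indexed)` asks the subtensor/dimshuffle lift rewrites to push the
# indexing "down" into the RV's parameters, yielding a graph equivalent to
# `at.random.normal(mu[1])`, from which a log-probability can be derived directly.
pulled = rv_pull_down(x_indexed)
```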


class MixtureRV(OpFromGraph):
    """A placeholder used to specify a log-likelihood for a mixture sub-graph."""

    default_output = 1
    # FIXME: This is just to appease `random_make_inplace`
    inplace = True

    @classmethod
    def create_node(cls, node, indices, mixture_rvs):
        out_var = node.default_output()

        inputs = list(indices) + list(mixture_rvs)

        mixture_op = cls(
            # The first and third parameters are simply placeholders so that the
            # arguments signature matches `RandomVariable`'s
            inputs,
            [NoneConst, out_var],
            inline=True,
            on_unused_input="ignore",
        )

        # Give this composite `Op` a `RandomVariable`-like interface
        mixture_op.name = f"{out_var.name if out_var.name else ''}-mixture"
        mixture_op.ndim_supp = out_var.ndim
        mixture_op.dtype = out_var.dtype
        mixture_op.ndims_params = [inp.ndim for inp in inputs]

        # new_node = mixture_op.make_node(None, None, None, *inputs)
        new_node = mixture_op(*inputs)

        return new_node.owner

    def make_node(self, *inputs):
        # Make the `make_node` signature consistent with the node inputs
        # TODO: This is a hack; make it less so.
        num_expected_inps = len(self.local_inputs) - len(self.shared_inputs)
        return super().make_node(*inputs[:num_expected_inps])

    def rng_fn(self, rng, *args, **kwargs):
        raise NotImplementedError()

    def get_non_shared_inputs(self, inputs):
        return inputs[: len(self.shared_inputs)]


# Allow `MixtureRV`s to be typed as `RandomVariable`s
RandomVariable.register(MixtureRV)
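Since `MixtureRV` actually subclasses `OpFromGraph`, the `register` call above is what lets `isinstance`-style checks elsewhere treat it as a random variable. A minimal sketch of the intended effect, assuming `RandomVariable.register` performs ABC-style virtual subclass registration (as the comment above suggests):

```python
from aesara.tensor.random.op import RandomVariable

from aeppl.mixture import MixtureRV

# After `RandomVariable.register(MixtureRV)`, virtual subclassing should make
# subclass/instance checks succeed even though `MixtureRV` derives from `OpFromGraph`.
assert issubclass(MixtureRV, RandomVariable)
```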


def get_stack_mixture_vars(
    node: Apply,
) -> Optional[List[TensorVariable]]:
    r"""Extract the mixture terms from a `*Subtensor*` applied to stacked `RandomVariable`\s."""
    if not isinstance(node.op, subtensor_ops):
        return  # noqa

    joined_rvs = node.inputs[0]

    # First, make sure that it's some sort of concatenation
    if not (joined_rvs.owner and isinstance(joined_rvs.owner.op, (MakeVector, Join))):
        # Node is not a compatible join `Op`
        return  # noqa

    if isinstance(joined_rvs.owner.op, MakeVector):
        mixture_rvs = joined_rvs.owner.inputs

    elif isinstance(joined_rvs.owner.op, Join):
        mixture_rvs = joined_rvs.owner.inputs[1:]
        join_axis = joined_rvs.owner.inputs[0]
        try:
            join_axis = int(get_constant_value(join_axis))
        except ValueError:
            # TODO: Support symbolic join axes
            return

        if join_axis != 0:
            # TODO: Support other join axes
            return

    if not all(
        rv.owner and isinstance(rv.owner.op, RandomVariable) for rv in mixture_rvs
    ):
        # Currently, all mixture components must be `RandomVariable` outputs
        # TODO: Allow constants and make them Dirac-deltas
        return

    return mixture_rvs
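To make the matched pattern concrete, here is an illustrative sketch (not part of the commit) of a graph this helper should recognize; the component and index variables are placeholders:

```python
import aesara.tensor as at

from aeppl.mixture import get_stack_mixture_vars

I_rv = at.random.bernoulli(0.5, name="I")      # discrete index with finite support
X_rv = at.random.normal(0.0, 1.0, name="X")    # first mixture component
Y_rv = at.random.normal(100.0, 1.0, name="Y")  # second mixture component

# `at.stack` creates a `MakeVector` (scalar components) or `Join` node; indexing it
# with `I_rv` produces the `*Subtensor*` node that this helper inspects.
M_rv = at.stack([X_rv, Y_rv])[I_rv]

# On this graph, the helper should return the component list `[X_rv, Y_rv]`.
components = get_stack_mixture_vars(M_rv.owner)
```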


@local_optimizer(subtensor_ops)
def mixture_replace(fgraph, node):
    r"""Identify mixture sub-graphs and replace them with a place-holder `Op`.

    The basic idea is to find ``stack(mixture_comps)[I_rv]``, where
    ``mixture_comps`` is a ``list`` of `RandomVariable`\s and ``I_rv`` is a
    `RandomVariable` with a discrete and finite support.

    From these terms, new terms ``Z_rv[i] = mixture_comps[i][i == I_rv]`` are
    created for each ``i`` in ``enumerate(mixture_comps)``.
    """

    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

    if rv_map_feature is None:
        return  # noqa

    out_var = node.default_output()

    if out_var not in rv_map_feature.rv_values:
        return

    mixture_res = get_stack_mixture_vars(node)

    if mixture_res is None:
        return

    mixture_rvs = mixture_res

    mixture_value_var = rv_map_feature.rv_values.pop(out_var, None)

    # We loop through mixture components and collect all the array elements
    # that belong to each one (by way of their indices).
    for i, component_rv in enumerate(mixture_rvs):
        if component_rv in rv_map_feature.rv_values:
            raise ValueError("A value variable was specified for a mixture component")
        component_rv.tag.ignore_logprob = True

    # Replace this sub-graph with a `MixtureRV`
    new_node = MixtureRV.create_node(node, node.inputs[1:], mixture_rvs)

    new_mixture_rv = new_node.default_output()
    new_mixture_rv.name = "mixture"
    rv_map_feature.rv_values[new_mixture_rv] = mixture_value_var

    # FIXME: This is pretty hackish
    fgraph.import_node(new_node, import_missing=True, reason="mixture_rv")

    return [new_mixture_rv]


@_logprob.register(MixtureRV)
def logprob_MixtureRV(op, value, *inputs, name=None, **kwargs):
    inputs = op.get_non_shared_inputs(inputs)

    subtensor_node = op.outputs[1].owner
    num_indices = len(subtensor_node.inputs) - 1
    indices = inputs[:num_indices]
    indices = indices_from_subtensor(
        getattr(subtensor_node.op, "idx_list", None), indices
    )
    comp_rvs = inputs[num_indices:]

    if value.ndim > 0:
        # TODO: Make the join axis to the left-most dimension (or transpose the
        # problem)
        join_axis = 0  # op.join_axis

        logp_val = at.full(tuple(at.shape(value)), -np.inf, dtype=value.dtype)

        for i, comp_rv in enumerate(comp_rvs):
            I_0 = indices[join_axis]
            join_indices = at.nonzero(at.eq(I_0, i))
            #
            # pre_index = (
            #     tuple(at.ogrid[tuple(slice(None, s) for s in at.shape(join_indices))])
            #     if I_0 is not None
            #     else (slice(None),)
            # )
            #
            # non_join_indices = pre_index + indices[1:]
            #
            # obs_i = value[join_indices][non_join_indices]
            obs_i = value[join_indices]

            bcast_shape = at.broadcast_shape(
                tuple(value.shape), tuple(comp_rv.shape), arrays_are_shapes=True
            )
            bcasted_comp_rv = at.broadcast_to(comp_rv, bcast_shape)
            zz = bcasted_comp_rv[join_indices]
            indexed_comp_rv = rv_pull_down(zz, inputs)
            # indexed_comp_rv = rv_pull_down(indexed_comp_rv[non_join_indices], inputs)

            logp_val = at.set_subtensor(
                # logp_val[join_indices][non_join_indices],
                logp_val[join_indices],
                logprob(indexed_comp_rv, obs_i),
            )

    else:
        logp_val = 0.0
        for i, comp_rv in enumerate(comp_rvs):
            comp_logp = logprob(comp_rv, value)
            logp_val += ifelse(
                at.eq(indices[0], i),
                comp_logp,
                at.as_tensor(0.0, dtype=comp_logp.dtype),
            )

    return logp_val


rv_sinking_db.register("mixture_replace", mixture_replace, -5, "basic")
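Putting the pieces together, this rewrite is meant to let `joint_logprob` derive a log-density for switch-style mixtures written as indexed stacks. The following end-to-end sketch is illustrative only; in particular, the exact `joint_logprob` call signature at this revision is an assumption:

```python
import aesara.tensor as at

from aeppl import joint_logprob

I_rv = at.random.bernoulli(0.5, name="I")
X_rv = at.random.normal(0.0, 1.0, name="X")
Y_rv = at.random.normal(100.0, 1.0, name="Y")

# A two-component mixture expressed as indexing into a stack of components
M_rv = at.stack([X_rv, Y_rv])[I_rv]
M_rv.name = "M"

# Value variables for the index and the mixture output
i_vv = I_rv.clone()
m_vv = M_rv.clone()

# With `mixture_replace` registered in the rewrite database, the indexed-stack
# sub-graph is swapped for a `MixtureRV` placeholder and `logprob_MixtureRV`
# supplies its term in the joint log-probability.  (Assumed call signature:
# a mapping from random variables to their value variables.)
logp = joint_logprob({M_rv: m_vv, I_rv: i_vv})
```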
14 changes: 13 additions & 1 deletion aeppl/opt.py
@@ -17,13 +17,17 @@
from aesara.tensor.subtensor import (
    AdvancedIncSubtensor,
    AdvancedIncSubtensor1,
    AdvancedSubtensor,
    AdvancedSubtensor1,
    IncSubtensor,
    Subtensor,
)
from aesara.tensor.var import TensorVariable

from aeppl.utils import indices_from_subtensor

inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
subtensor_ops = (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)


class PreserveRVMappings(Feature):
@@ -141,7 +145,15 @@ def naive_bcast_rv_lift(fgraph, node):
 
     rng, size, dtype, *dist_params = lifted_node.inputs
 
-    new_dist_params = [at.broadcast_to(param, bcast_shape) for param in dist_params]
+    new_dist_params = [
+        at.broadcast_to(
+            param,
+            at.broadcast_shape(
+                tuple(param.shape), tuple(bcast_shape), arrays_are_shapes=True
+            ),
+        )
+        for param in dist_params
+    ]
     bcasted_node = lifted_node.op.make_node(rng, size, dtype, *new_dist_params)
 
     if aesara.config.compute_test_value != "off":
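The reworked `naive_bcast_rv_lift` no longer broadcasts a parameter straight to `bcast_shape`; it first combines the parameter's own shape with `bcast_shape`. A small NumPy sketch (illustrative shapes, not part of the commit) of why that matters:

```python
import numpy as np

param = np.zeros((3, 1))  # e.g. a column of location parameters
bcast_shape = (4,)        # target shape implied by the lifted node

# Broadcasting the parameter directly to `bcast_shape` fails: (3, 1) cannot be
# broadcast to (4,).
# np.broadcast_to(param, bcast_shape)  # -> ValueError

# Combining the two shapes first (as `at.broadcast_shape(..., arrays_are_shapes=True)`
# does symbolically) yields a common shape that both sides can broadcast to.
common = np.broadcast_shapes(param.shape, bcast_shape)  # (3, 4), NumPy >= 1.20
param_bcast = np.broadcast_to(param, common)
```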
2 changes: 1 addition & 1 deletion setup.py
@@ -18,7 +18,7 @@
         "numpy>=1.18.1",
         "scipy>=1.4.0",
         "numba",
-        "aesara >= 2.0.12",
+        "aesara >= 2.1.0",
     ],
     tests_require=["pytest"],
     long_description=open("README.md").read() if exists("README.md") else "",