Implement a general mixture log-probability rewrite
brandonwillard committed Jul 11, 2021
1 parent edff82c commit b81d998
Showing 4 changed files with 363 additions and 5 deletions.
6 changes: 6 additions & 0 deletions aeppl/joint_logprob.py
@@ -158,6 +158,12 @@ def joint_logprob(
node.op, q_rv_value_var, *value_var_inputs, name=q_rv_var.name, **kwargs
)

if q_logprob_var is None:
# When a `_logprob` implementation returns `None`, it signals that the
# node can be skipped because, for example, its log-likelihood is
# already determined by its inputs.
continue

if q_rv_var.name:
q_logprob_var.name = f"{q_rv_var.name}_logprob"

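To illustrate the convention this hunk relies on: a `_logprob` implementation registered for an `Op` can return `None` to tell `joint_logprob` to skip that node. The `PassThroughOp` below is a hypothetical example and not part of this commit; only the `_logprob.register` pattern is taken from this change (see `logprob_MixtureRV` in `aeppl/opt.py` below).

    from aesara.graph.op import Op

    from aeppl.logprob import _logprob


    class PassThroughOp(Op):
        """A hypothetical `Op` whose log-likelihood is carried entirely by its inputs."""


    @_logprob.register(PassThroughOp)
    def logprob_passthrough(op, value, *inputs, **kwargs):
        # Returning `None` makes `joint_logprob` skip this node.
        return None
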
257 changes: 253 additions & 4 deletions aeppl/opt.py
@@ -1,12 +1,18 @@
from typing import Dict
from typing import Dict, List, Optional

import aesara
import aesara.tensor as at
import numpy as np
from aesara.compile.builders import OpFromGraph
from aesara.compile.mode import optdb
from aesara.graph.basic import Apply
from aesara.graph.features import Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value
from aesara.graph.opt import local_optimizer
from aesara.graph.opt import local_optimizer, pre_greedy_local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.ifelse import ifelse
from aesara.tensor.basic import Join, MakeVector
from aesara.tensor.extra_ops import BroadcastTo
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import (
@@ -17,13 +23,19 @@
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
)
from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorVariable

from aeppl.utils import indices_from_subtensor
from aeppl.logprob import _logprob, logprob
from aeppl.utils import get_constant_value, indices_from_subtensor

inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
subtensor_ops = (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)


class PreserveRVMappings(Feature):
@@ -141,7 +153,15 @@ def naive_bcast_rv_lift(fgraph, node):

rng, size, dtype, *dist_params = lifted_node.inputs

new_dist_params = [at.broadcast_to(param, bcast_shape) for param in dist_params]
new_dist_params = [
at.broadcast_to(
param,
at.broadcast_shape(
tuple(param.shape), tuple(bcast_shape), arrays_are_shapes=True
),
)
for param in dist_params
]
bcasted_node = lifted_node.op.make_node(rng, size, dtype, *new_dist_params)

if aesara.config.compute_test_value != "off":
@@ -150,11 +170,240 @@ def naive_bcast_rv_lift(fgraph, node):
return [bcasted_node.outputs[1]]
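
For context, a hedged sketch of the `at.broadcast_shape(..., arrays_are_shapes=True)` call introduced above (presumably the reason for the `aesara >= 2.1.0` bump in `setup.py` below); the shapes here are illustrative only:

    import aesara.tensor as at

    param = at.matrix("param")   # e.g. a distribution parameter of shape (1, 3)
    bcast_shape = (5, 1, 3)      # the target broadcast shape

    # With `arrays_are_shapes=True`, both arguments are treated as shape
    # tuples instead of arrays whose shapes should be inferred.
    out_shape = at.broadcast_shape(
        tuple(param.shape), bcast_shape, arrays_are_shapes=True
    )
    new_param = at.broadcast_to(param, out_shape)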


class MixtureRV(OpFromGraph):
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""

default_output = 1
# FIXME: This is just to appease `random_make_inplace`
inplace = True

@classmethod
def create_node(cls, node, indices, mixture_rvs):
out_var = node.default_output()

inputs = list(indices) + list(mixture_rvs)
# mixture_op = cls(
# f"{out_var.name if out_var.name else ''}-mixture",
# out_var.ndim,
# [inp.ndim for inp in inputs],
# out_var.dtype,
# inplace=False,
# )

mixture_op = cls(
# The first and third parameters are simply placeholders so that the
# argument signature matches `RandomVariable`'s
inputs,
[NoneConst, out_var],
inline=True,
on_unused_input="ignore",
)

# Give this composite `Op` a `RandomVariable`-like interface
mixture_op.name = f"{out_var.name if out_var.name else ''}-mixture"
mixture_op.ndim_supp = out_var.ndim
mixture_op.dtype = out_var.dtype
mixture_op.ndims_params = [inp.ndim for inp in inputs]

# new_node = mixture_op.make_node(None, None, None, *inputs)
new_node = mixture_op(*inputs)

return new_node.owner

def make_node(self, *inputs):
# Make the `make_node` signature consistent with the node inputs
# TODO: This is a hack; make it less so.
num_expected_inps = len(self.local_inputs) - len(self.shared_inputs)
if len(inputs) > num_expected_inps:
inputs = inputs[:num_expected_inps]
return super().make_node(*inputs)

def rng_fn(self, rng, *args, **kwargs):
raise NotImplementedError()

def get_non_shared_inputs(self, inputs):
# Drop the shared inputs that `OpFromGraph.make_node` appends to the end
return inputs[: len(inputs) - len(self.shared_inputs)]


# Allow `MixtureRV`s to be typed as `RandomVariable`s
RandomVariable.register(MixtureRV)


def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
"""Pull a ``RandomVariable`` ``Op`` down through a graph, when possible."""
if dont_touch_vars is None:
dont_touch_vars = []

fgraph = FunctionGraph(outputs=dont_touch_vars, clone=False)

return pre_greedy_local_optimizer(
fgraph,
[
local_dimshuffle_rv_lift,
local_subtensor_rv_lift,
naive_bcast_rv_lift,
],
x,
)
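
A hedged sketch of the intent behind `rv_pull_down`: a `Subtensor` applied to a `RandomVariable`'s output should, when possible, be pushed into the variable's parameters by the lift rewrites listed above. The exact rewritten graph depends on Aesara's rewrites, so the comments describe the intended result rather than a guaranteed form.

    import aesara.tensor as at

    from aeppl.opt import rv_pull_down

    mu = at.vector("mu")
    X_rv = at.random.normal(mu, 1.0, name="X")  # a vector of independent normals

    # Indexing the RV's output ...
    indexed = X_rv[0]

    # ... is "pulled down" into the parameters, yielding a graph roughly
    # equivalent to `at.random.normal(mu[0], 1.0)`.
    pulled = rv_pull_down(indexed)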


def get_stack_mixture_vars(
node: Apply,
) -> Optional[List[TensorVariable]]:
r"""Extract the mixture terms from a `*Subtensor*` applied to stacked `RandomVariable`\s."""
if not isinstance(node.op, subtensor_ops):
return

joined_rvs = node.inputs[0]

# First, make sure that it's some sort of combination
if not (joined_rvs.owner and isinstance(joined_rvs.owner.op, (MakeVector, Join))):
# Node is not a compatible join `Op`
return None

if isinstance(joined_rvs.owner.op, MakeVector):
mixture_rvs = joined_rvs.owner.inputs

elif isinstance(joined_rvs.owner.op, Join):
mixture_rvs = joined_rvs.owner.inputs[1:]
join_axis = joined_rvs.owner.inputs[0]
try:
join_axis = int(get_constant_value(join_axis))
except ValueError:
# TODO: Support symbolic join axes
return None

if join_axis != 0:
# TODO: Support other join axes
return None

if not all(
rv.owner and isinstance(rv.owner.op, RandomVariable) for rv in mixture_rvs
):
# Currently, all mixture components must be `RandomVariable` outputs
# TODO: Allow constants and make them Dirac-deltas
return None

return mixture_rvs
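
A small, hedged illustration of the pattern this helper recognizes, assuming `at.stack` of scalar variables produces a `MakeVector` node (as the current Aesara implementation does):

    import aesara.tensor as at

    from aeppl.opt import get_stack_mixture_vars

    I_rv = at.random.bernoulli(0.5, name="I")
    X_rv = at.random.normal(0.0, 1.0, name="X")
    Y_rv = at.random.normal(10.0, 1.0, name="Y")

    # `stack`-ing the scalar components and indexing with `I_rv` produces the
    # `*Subtensor*`-of-`MakeVector` pattern handled above.
    M_rv = at.stack([X_rv, Y_rv])[I_rv]

    # Expected result: the component `RandomVariable`s, i.e. [X_rv, Y_rv]
    components = get_stack_mixture_vars(M_rv.owner)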


@local_optimizer(subtensor_ops)
def mixture_replace(fgraph, node):
r"""Identify mixture sub-graphs and replace them with a place-holder `Op`.
The basic idea is to find ``stack(mixture_comps)[I_rv]``, where
``mixture_comps`` is a ``list`` of `RandomVariable`\s and ``I_rv`` is a
`RandomVariable` with a discrete and finite support.
From these terms, new terms ``Z_rv[i] = mixture_comps[i][i == I_rv]`` are
created for each ``i`` in ``enumerate(mixture_comps)``.
"""

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return

out_var = node.default_output()

if out_var not in rv_map_feature.rv_values:
return

mixture_res = get_stack_mixture_vars(node)

if mixture_res is None:
return

mixture_rvs = mixture_res

mixture_value_var = rv_map_feature.rv_values.pop(out_var, None)

# We loop through mixture components and collect all the array elements
# that belong to each one (by way of their indices).
for i, component_rv in enumerate(mixture_rvs):
if component_rv in rv_map_feature.rv_values:
raise ValueError("A value variable was specified for a mixture component")
component_rv.tag.ignore_logprob = True

# Replace this sub-graph with a `MixtureRV`
new_node = MixtureRV.create_node(node, node.inputs[1:], mixture_rvs)

new_mixture_rv = new_node.default_output()
new_mixture_rv.name = "mixture"
rv_map_feature.rv_values[new_mixture_rv] = mixture_value_var

# FIXME: This is pretty hackish
fgraph.import_node(new_node, import_missing=True, reason="mixture_rv")

return [new_mixture_rv]
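
Putting the pieces together, a hedged end-to-end sketch of the kind of graph `mixture_replace` targets; the two-argument `joint_logprob(var, rv_values)` call below is assumed from the `joint_logprob.py` hunk above and may not match the final API exactly.

    import aesara.tensor as at

    from aeppl.joint_logprob import joint_logprob

    I_rv = at.random.bernoulli(0.5, name="I")
    X_rv = at.random.normal(0.0, 1.0, name="X")
    Y_rv = at.random.normal(10.0, 1.0, name="Y")

    # The mixture is written directly as an indexed stack of its components.
    M_rv = at.stack([X_rv, Y_rv])[I_rv]
    M_rv.name = "M"

    # Value variables for the mixture and the index.
    m_vv = M_rv.clone()
    m_vv.name = "m"
    i_vv = I_rv.clone()
    i_vv.name = "i"

    # `mixture_replace` swaps the indexed stack for a `MixtureRV`, whose
    # log-probability is the component log-density selected by `i`.
    logp = joint_logprob(M_rv, {M_rv: m_vv, I_rv: i_vv})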


@_logprob.register(MixtureRV)
def logprob_MixtureRV(op, value, *inputs, name=None, **kwargs):
inputs = op.get_non_shared_inputs(inputs)

subtensor_node = op.outputs[1].owner
num_indices = len(subtensor_node.inputs) - 1
indices = inputs[:num_indices]
indices = indices_from_subtensor(
getattr(subtensor_node.op, "idx_list", None), indices
)
comp_rvs = inputs[num_indices:]

if value.ndim > 0:
# TODO: Move the join axis to the left-most dimension (or transpose the
# problem)
join_axis = 0 # op.join_axis

logp_val = at.full(tuple(at.shape(value)), -np.inf, dtype=value.dtype)

for i, comp_rv in enumerate(comp_rvs):
I_0 = indices[join_axis]
join_indices = at.nonzero(at.eq(I_0, i))
#
# pre_index = (
# tuple(at.ogrid[tuple(slice(None, s) for s in at.shape(join_indices))])
# if I_0 is not None
# else (slice(None),)
# )
#
# non_join_indices = pre_index + indices[1:]
#
# obs_i = value[join_indices][non_join_indices]
obs_i = value[join_indices]

bcast_shape = at.broadcast_shape(
tuple(value.shape), tuple(comp_rv.shape), arrays_are_shapes=True
)
bcasted_comp_rv = at.broadcast_to(comp_rv, bcast_shape)
zz = bcasted_comp_rv[join_indices]
indexed_comp_rv = rv_pull_down(zz, inputs)
# indexed_comp_rv = rv_pull_down(indexed_comp_rv[non_join_indices], inputs)

logp_val = at.set_subtensor(
# logp_val[join_indices][non_join_indices],
logp_val[join_indices],
logprob(indexed_comp_rv, obs_i),
)

else:
logp_val = 0.0
for i, comp_rv in enumerate(comp_rvs):
comp_logp = logprob(comp_rv, value)
logp_val += ifelse(
at.eq(indices[0], i),
comp_logp,
at.as_tensor(0.0, dtype=comp_logp.dtype),
)

return logp_val
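
For intuition, a NumPy-only sketch of what the `value.ndim > 0` branch computes, with normal component densities assumed purely for illustration: each element of the value gets the log-density of the component selected by the corresponding index.

    import numpy as np
    from scipy import stats

    indices = np.array([0, 1, 1, 0])          # values of the index RV
    value = np.array([-0.5, 9.0, 11.0, 0.3])  # values of the mixture RV
    mu = [0.0, 10.0]                          # assumed component means

    logp_val = np.full(value.shape, -np.inf)
    for i, mu_i in enumerate(mu):
        # Analogue of `at.nonzero(at.eq(I_0, i))` and `at.set_subtensor` above
        sel = np.nonzero(indices == i)
        logp_val[sel] = stats.norm(mu_i, 1.0).logpdf(value[sel])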


logprob_rewrites_db = SequenceDB()
logprob_rewrites_db.register("canonicalize", optdb["canonicalize"], -10, "basic")
rv_sinking_db = EquilibriumDB()
rv_sinking_db.register("dimshuffle_lift", local_dimshuffle_rv_lift, -5, "basic")
rv_sinking_db.register("subtensor_lift", local_subtensor_rv_lift, -5, "basic")
rv_sinking_db.register("broadcast_to_lift", naive_bcast_rv_lift, -5, "basic")
rv_sinking_db.register("incsubtensor_lift", incsubtensor_rv_replace, -5, "basic")
rv_sinking_db.register("mixture_replace", mixture_replace, -5, "basic")
logprob_rewrites_db.register("sinking", rv_sinking_db, -10, "basic")
2 changes: 1 addition & 1 deletion setup.py
@@ -18,7 +18,7 @@
"numpy>=1.18.1",
"scipy>=1.4.0",
"numba",
"aesara >= 2.0.12",
"aesara >= 2.1.0",
],
tests_require=["pytest"],
long_description=open("README.md").read() if exists("README.md") else "",
