Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 0 additions & 99 deletions effectful/internals/disjoint_set.py

This file was deleted.

205 changes: 113 additions & 92 deletions effectful/ops/monoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from graphlib import TopologicalSorter
from typing import Annotated, Any

from effectful.internals.disjoint_set import DisjointSet
from effectful.ops.semantics import (
coproduct,
evaluate,
Expand Down Expand Up @@ -62,6 +61,51 @@ def outer_stream(streams: Streams) -> Iterable[tuple[Operation, Stream, Streams]
)


def inner_stream(
streams: dict[Operation, Expr],
) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]:
"""Returns the streams that can be ordered innermost in the loop nest as
well as the remaining streams in the nest.

"""
stream_vars = set(streams.keys())

no_dependents = set()
succ = defaultdict(set)
for k, v in streams.items():
preds = fvsof(v) & stream_vars
if preds:
for pred in preds:
succ[pred].add(k)
else:
no_dependents.add(k)

topo = TopologicalSorter(succ)
topo.prepare()
return (
({k: v for (k, v) in streams.items() if k != op}, op, streams[op])
for op in set(topo.get_ready()) | no_dependents
)


def inner_streams_first(streams: dict[Operation, Expr]) -> Iterable[Operation]:
"""Iterable over streams where dependent streams precede their dependencies."""
stream_vars = set(streams.keys())

no_dependents = set()
succ = defaultdict(set)
for k, v in streams.items():
preds = fvsof(v) & stream_vars
if preds:
for pred in preds:
succ[pred].add(k)
else:
no_dependents.add(k)

topo = TopologicalSorter(succ)
return topo.static_order()


class Monoid[W]:
"""A monoid with ``plus`` and ``reduce`` :class:`Operation` s."""

Expand Down Expand Up @@ -392,110 +436,87 @@ def reduce(self, monoid, body, streams):


class ReduceFactorization(ObjectInterpretation):
"""
Implements factorization of independent terms.
For example, when having two independent distributions,
we can rewrite their marginalization as:
∫p(x)⋅q(y)dxdy => ∫p(x)dx ⋅ ∫q(y)dy

More specifically, in terms of reduces we are performing:
reduce(R, (S₁ × ... × Sₖ) , A₁ * ... * Aₖ)
=> reduce(R, S₁, A₁) * ... * reduce(R, Sₖ, Aₖ)
where free(Aᵢ) ∩ free(Aⱼ) ∩ S = ∅
and free(Aᵢ) ∩ S ⊆ Sᵢ
"""reduce(⊗(F_v ∪ F_rest), {v} ∪ S) = reduce(⊗F_rest ⊗ reduce(⊗F_v, {v}), S)

where F_v = factors mentioning v, F_rest = the others. Fires only when
v has no dependents among the remaining streams (so it can be innermost)
and F_rest is nonempty (universal variables stay in the outer core).
"""

@implements(Monoid.reduce)
def reduce(self, monoid, body, streams):
if not is_commutative(monoid):
return fwd()
if (
isinstance(body, Term)
if not (
is_commutative(monoid)
and isinstance(body, Term)
and _is_monoid_plus(body.op)
and distributes_over(body.op.__self__, monoid)
):
inner_monoid: Monoid = body.op.__self__
stream_vars = set(streams.keys())
factors = [(arg, fvsof(arg)) for arg in body.args]
stream_ids = {v: i for (i, v) in enumerate(stream_vars)}
ds = DisjointSet(len(streams))

# streams are in the same partition as their dependencies
for stream_var, stream_id in stream_ids.items():
stream_body = streams[stream_var]
deps = sorted([stream_ids[v] for v in fvsof(stream_body) & stream_vars])
ds.union(stream_id, *deps)

# factors are in the same partition as their dependencies
for _, factor_fvs in factors:
factor_streams = sorted(
[stream_ids[v] for v in (factor_fvs & stream_vars)]
)
ds.union(*factor_streams)

placed_streams = set()
new_reduces = []
for stream_key in streams:
if stream_key in placed_streams:
continue

partition = ds.find(stream_ids[stream_key])
partition_streams = {
k: v
for (k, v) in streams.items()
if ds.find(stream_ids[k]) == partition
}
partition_stream_keys = set(partition_streams.keys())

partition_factors = [
t for t in factors if (t[1] & partition_stream_keys)
]

assert all(
(t[1] & stream_vars) <= partition_stream_keys
for t in partition_factors
), "partition contains all streams required by factor"

partition_term = inner_monoid.plus(*(t[0] for t in partition_factors))
new_reduces.append((partition_term, partition_streams))
placed_streams |= partition_stream_keys

constant_factors = [t for (t, fvs) in factors if not (fvs & stream_vars)]

if len(new_reduces) > 1:
result = inner_monoid.plus(
*constant_factors, *(monoid.reduce(*args) for args in new_reduces)
)
return result
return fwd()

return fwd()
inner = body.op.__self__
stream_keys = set(streams)
factors = [(a, fvsof(a)) for a in body.args]

# candidates: innermost-eligible (no remaining stream depends on v),
# non-universal (some factor doesn't mention v)
support: dict = {}
for v in streams:
if any(v in fvsof(s) for k, s in streams.items() if k is not v):
continue
f_v = frozenset(i for i, (_, fvs) in enumerate(factors) if v in fvs)
if len(f_v) == len(factors):
continue # v is universal: leave it in the outer core
support[v] = f_v

# eliminate a variable with subset-minimal factor support
# (leaves-first; canonical on hierarchical/laminar supports)
inner_stream = None
inner_factor_ids = None
for v, f_v in support.items():
if any(u_sup < f_v for u, u_sup in support.items() if u is not v):
continue
inner_stream = v
inner_factor_ids = f_v
break

def inner_stream(
streams: dict[Operation, Expr],
) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]:
"""Returns the streams that can be ordered innermost in the loop nest as
well as the remaining streams in the nest.
if not inner_stream or not inner_factor_ids:
return fwd()

"""
stream_vars = set(streams.keys())
inner_factors = [factors[i][0] for i in sorted(inner_factor_ids)]
inner_stream_keys = {inner_stream}
inner_deps = set().union(
*(factors[i][1] for i in f_v), fvsof(streams[v]) & stream_keys
)

no_dependents = set()
succ = defaultdict(set)
for k, v in streams.items():
preds = fvsof(v) & stream_vars
if preds:
for pred in preds:
succ[pred].add(k)
else:
no_dependents.add(k)
outer_factors = [a for i, (a, _) in enumerate(factors) if i not in f_v]
outer_stream_keys = stream_keys - inner_stream_keys
outer_factor_deps = set().union(
*(vars for i, (_, vars) in enumerate(factors) if i not in f_v)
)

topo = TopologicalSorter(succ)
topo.prepare()
return (
({k: v for (k, v) in streams.items() if k != op}, op, streams[op])
for op in set(topo.get_ready()) | no_dependents
)
# find all streams that are used in the inner factors/streams and are
# not used by the outer factors/streams
# this has to be done iteratively, because moving a stream inward
# reduces the outer dependency set
# ensures that no future factorization application creates a reduce that
# fuses with with the inner reduce
for s in inner_streams_first(streams):
outer_stream_deps = (
set().union(*(fvsof(streams[k]) for k in outer_stream_keys))
& stream_keys
)
outer_deps = outer_factor_deps | outer_stream_deps
if s in inner_deps and s not in outer_deps:
inner_stream_keys |= {s}
inner_deps |= stream_keys & fvsof(streams[s])
outer_stream_keys -= {s}

inner_streams = {k: v for (k, v) in streams.items() if k in inner_stream_keys}
inner_red = monoid.reduce(inner.plus(*inner_factors), inner_streams)

rest_streams = {k: s for k, s in streams.items() if k in outer_stream_keys}
new_body = inner.plus(*outer_factors, inner_red)
return monoid.reduce(new_body, rest_streams) if rest_streams else new_body


class ReduceDistributeCartesianProduct(ObjectInterpretation):
Expand Down
10 changes: 10 additions & 0 deletions tests/_monoid_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,11 @@ def strategy(
return st.sampled_from(self._unary_num_fns)
case (builtins.int, builtins.int), "scalar":
return st.sampled_from(self._binary_num_fns)
case (builtins.int, builtins.int, builtins.int), "scalar":
return st.tuples(
st.sampled_from(self._binary_num_fns),
st.sampled_from(self._binary_num_fns),
).map(lambda fg: lambda a, b, c: fg[0](a, fg[1](b, c)))
case (builtins.int,), "stream":
return st.sampled_from(self._unary_list_fns)
raise NotImplementedError(
Expand Down Expand Up @@ -389,6 +394,11 @@ def strategy(
return st.sampled_from(self._unary_jax_scalar_fns)
case (jax.Array, jax.Array), "scalar":
return st.sampled_from(self._binary_jax_scalar_fns)
case (jax.Array, jax.Array, jax.Array), "scalar":
return st.tuples(
st.sampled_from(self._binary_jax_scalar_fns),
st.sampled_from(self._binary_jax_scalar_fns),
).map(lambda fg: lambda a, b, c: fg[0](a, fg[1](b, c)))
case (jax.Array,), "stream":
return st.sampled_from(self._unary_jax_stream_fns)

Expand Down
Loading
Loading