Skip to content

Commit

Permalink
Add Numba support for IfElse Op
Browse files Browse the repository at this point in the history
  • Loading branch information
kc611 authored and brandonwillard committed Oct 19, 2021
1 parent c2e3dbb commit fa40b9b
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 1 deletion.
30 changes: 30 additions & 0 deletions aesara/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
from aesara.ifelse import IfElse
from aesara.link.utils import (
compile_function_src,
fgraph_to_python,
Expand Down Expand Up @@ -682,3 +683,32 @@ def batched_dot(x, y):
# NOTE: The remaining `aesara.tensor.blas` `Op`s appear unnecessary, because
# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
# optimizations are apparently already performed by Numba


@numba_funcify.register(IfElse)
def numba_funcify_IfElse(op, **kwargs):
    """Generate a Numba implementation for an `IfElse` `Op`.

    The compiled function receives the condition followed by the
    concatenated branch inputs: the first ``n_outs`` entries of ``args``
    are the values produced when the condition is truthy, and the
    remaining ``n_outs`` entries are the values for the false case.
    """
    # Number of outputs per branch; captured as a closure constant, so Numba
    # freezes it at compile time.
    n_outs = op.n_outs

    if n_outs > 1:
        # Multiple outputs: return the selected half of `args` as a tuple.

        @numba.njit
        def ifelse(cond, *args):
            if cond:
                res = args[:n_outs]
            else:
                res = args[n_outs:]

            return res

    else:
        # Single output: unpack the one-element tuple so callers get the
        # value itself.  A separate jitted function is compiled for each
        # case because the two variants have different return types
        # (tuple vs. element), which Numba cannot unify in one function.

        @numba.njit
        def ifelse(cond, *args):
            if cond:
                res = args[:n_outs]
            else:
                res = args[n_outs:]

            return res[0]

    return ifelse
63 changes: 62 additions & 1 deletion tests/link/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Apply, Constant
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.op import Op, get_test_value
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type
from aesara.ifelse import ifelse
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite
Expand Down Expand Up @@ -3076,3 +3077,63 @@ def power_of_2(previous_power, max_value):
np.array(45).astype(config.floatX),
]
compare_numba_and_py(out_fg, test_input_vals)


@pytest.mark.parametrize(
    "inputs, cond_fn, true_vals, false_vals",
    [
        # `true_vals`/`false_vals` are callables over the same per-case
        # `inputs` (like `cond_fn`), so the branch graphs are built from
        # those exact variables.  Referencing ambient module-level names
        # here would produce branch outputs disconnected from the
        # `FunctionGraph` inputs.
        (
            [],
            lambda: np.array(True),
            lambda: np.r_[1, 2, 3],
            lambda: np.r_[-1, -2, -3],
        ),
        (
            [set_test_value(aet.dscalar(), np.array(0.2, dtype=np.float64))],
            lambda x: x < 0.5,
            lambda x: np.r_[1, 2, 3],
            lambda x: np.r_[-1, -2, -3],
        ),
        (
            [
                set_test_value(aet.dscalar(), np.array(0.3, dtype=np.float64)),
                set_test_value(aet.dscalar(), np.array(0.5, dtype=np.float64)),
            ],
            lambda x, y: x > y,
            lambda x, y: x,
            lambda x, y: y,
        ),
        (
            [
                set_test_value(aet.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
                set_test_value(aet.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
            ],
            lambda x, y: aet.all(x > y),
            lambda x, y: x,
            lambda x, y: y,
        ),
        (
            [
                set_test_value(aet.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
                set_test_value(aet.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
            ],
            lambda x, y: aet.all(x > y),
            lambda x, y: [x, 2 * x],
            lambda x, y: [y, 3 * y],
        ),
        (
            [
                set_test_value(aet.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
                set_test_value(aet.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
            ],
            lambda x, y: aet.all(x > y),
            lambda x, y: [x, 2 * x],
            lambda x, y: [y, 3 * y],
        ),
    ],
)
def test_numba_ifelse(inputs, cond_fn, true_vals, false_vals):
    """Check the Numba `IfElse` implementation against the Python backend.

    Covers no-input/constant conditions, scalar and vector conditions, and
    both single-output and multi-output branches.
    """
    # Build condition and both branch value lists from the per-case inputs.
    out = ifelse(cond_fn(*inputs), true_vals(*inputs), false_vals(*inputs))

    # `ifelse` returns a single variable for one output and a list for many;
    # normalize so `FunctionGraph` always receives a list of outputs.
    if not isinstance(out, list):
        out = [out]

    out_fg = FunctionGraph(inputs, out)

    compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])

0 comments on commit fa40b9b

Please sign in to comment.