From fa40b9bc09bf9464e3d1bede37aa3a8876b3a8b9 Mon Sep 17 00:00:00 2001
From: kc611
Date: Fri, 15 Oct 2021 16:31:35 +0530
Subject: [PATCH] Add Numba support for IfElse Op

---
 aesara/link/numba/dispatch/basic.py | 30 ++++++++++++++
 tests/link/test_numba.py            | 63 ++++++++++++++++++++++++++++-
 2 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py
index 883ad2b20b..d078e2ddff 100644
--- a/aesara/link/numba/dispatch/basic.py
+++ b/aesara/link/numba/dispatch/basic.py
@@ -16,6 +16,7 @@
 from aesara.graph.basic import Apply
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.type import Type
+from aesara.ifelse import IfElse
 from aesara.link.utils import (
     compile_function_src,
     fgraph_to_python,
@@ -682,3 +683,32 @@ def batched_dot(x, y):
 # NOTE: The remaining `aesara.tensor.blas` `Op`s appear unnecessary, because
 # they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
 # optimizations are apparently already performed by Numba
+
+
+@numba_funcify.register(IfElse)
+def numba_funcify_IfElse(op, **kwargs):
+    n_outs = op.n_outs
+
+    if n_outs > 1:
+
+        @numba.njit
+        def ifelse(cond, *args):
+            if cond:
+                res = args[:n_outs]
+            else:
+                res = args[n_outs:]
+
+            return res
+
+    else:
+
+        @numba.njit
+        def ifelse(cond, *args):
+            if cond:
+                res = args[:n_outs]
+            else:
+                res = args[n_outs:]
+
+            return res[0]
+
+    return ifelse
diff --git a/tests/link/test_numba.py b/tests/link/test_numba.py
index e31d8db443..6504182e7a 100644
--- a/tests/link/test_numba.py
+++ b/tests/link/test_numba.py
@@ -22,9 +22,10 @@
 from aesara.compile.sharedvalue import SharedVariable
 from aesara.graph.basic import Apply, Constant
 from aesara.graph.fg import FunctionGraph
-from aesara.graph.op import Op
+from aesara.graph.op import Op, get_test_value
 from aesara.graph.optdb import OptimizationQuery
 from aesara.graph.type import Type
+from aesara.ifelse import ifelse
 from aesara.link.numba.dispatch import basic as numba_basic
 from aesara.link.numba.linker import NumbaLinker
 from aesara.scalar.basic import Composite
@@ -3076,3 +3077,63 @@ def power_of_2(previous_power, max_value):
         np.array(45).astype(config.floatX),
     ]
     compare_numba_and_py(out_fg, test_input_vals)
+
+
+@pytest.mark.parametrize(
+    "inputs, cond_fn, true_vals, false_vals",
+    [
+        ([], lambda: np.array(True), np.r_[1, 2, 3], np.r_[-1, -2, -3]),
+        (
+            [set_test_value(aet.dscalar(), np.array(0.2, dtype=np.float64))],
+            lambda x: x < 0.5,
+            np.r_[1, 2, 3],
+            np.r_[-1, -2, -3],
+        ),
+        (
+            [
+                set_test_value(aet.dscalar(), np.array(0.3, dtype=np.float64)),
+                set_test_value(aet.dscalar(), np.array(0.5, dtype=np.float64)),
+            ],
+            lambda x, y: x > y,
+            x,
+            y,
+        ),
+        (
+            [
+                set_test_value(aet.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
+                set_test_value(aet.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
+            ],
+            lambda x, y: aet.all(x > y),
+            x,
+            y,
+        ),
+        (
+            [
+                set_test_value(aet.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
+                set_test_value(aet.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
+            ],
+            lambda x, y: aet.all(x > y),
+            [x, 2 * x],
+            [y, 3 * y],
+        ),
+        (
+            [
+                set_test_value(aet.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
+                set_test_value(aet.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
+            ],
+            lambda x, y: aet.all(x > y),
+            [x, 2 * x],
+            [y, 3 * y],
+        ),
+    ],
+)
+def test_numba_ifelse(inputs, cond_fn, true_vals, false_vals):
+
+    out = ifelse(cond_fn(*inputs), true_vals, false_vals)
+
+    if not isinstance(out, list):
+        out = [out]
+
+    out_fg = FunctionGraph(inputs, out)
+
+    compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])
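
A minimal usage sketch follows (not part of the patch). It assumes an environment with Aesara and Numba installed and that Aesara's "NUMBA" compile mode is available; the variable names are illustrative only. It shows what this change enables end to end: a lazy `ifelse` graph compiled through the Numba backend. Note the design choice in the dispatch above: two nearly identical `njit` thunks are registered so that the single-output case can return a bare value (`res[0]`) rather than a one-element tuple.

    # Sketch only: compile a lazy conditional graph with the Numba backend.
    import numpy as np

    import aesara
    import aesara.tensor as aet
    from aesara.ifelse import ifelse

    x = aet.dvector("x")
    y = aet.dvector("y")

    # Unlike `aet.switch`, `ifelse` evaluates only the selected branch.
    z = ifelse(aet.all(x > y), x - y, x + y)

    fn = aesara.function([x, y], z, mode="NUMBA")
    print(fn(np.r_[1.0, 2.0], np.r_[0.5, 0.5]))  # -> [0.5 1.5]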