Skip to content

Commit

Permalink
Merge branch 'main' into fix_dimshuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
kc611 committed Oct 22, 2021
2 parents d4dd922 + 0f8c81c commit cda0357
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
10 changes: 10 additions & 0 deletions aesara/link/numba/dispatch/scalar.py
Expand Up @@ -21,6 +21,7 @@
Clip,
Composite,
Identity,
Inv,
Mul,
ScalarOp,
Second,
Expand Down Expand Up @@ -170,3 +171,12 @@ def second(x, y):
return y

return second


@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
@numba.njit(inline="always")
def inv(x):
return 1 / x

return inv
19 changes: 19 additions & 0 deletions tests/link/test_numba.py
Expand Up @@ -804,6 +804,25 @@ def test_Cast(v, dtype):
)


@pytest.mark.parametrize(
"v, dtype",
[
(set_test_value(aet.iscalar(), np.array(10, dtype="int32")), aesb.float64),
],
)
def test_Inv(v, dtype):
g = aesb.inv(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)


@pytest.mark.parametrize(
"v, shape, ndim",
[
Expand Down

0 comments on commit cda0357

Please sign in to comment.