Skip to content

Commit

Permalink
Use the correct dtype object in numba_funcify_CAReduce
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed May 23, 2021
1 parent 248ce6d commit 1366221
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions aesara/link/numba/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,24 +519,22 @@ def numba_funcify_CAReduce(op, node, **kwargs):

scalar_op_identity = np.asarray(op.scalar_op.identity, dtype=np_acc_dtype)

acc_dtype = numba.np.numpy_support.from_dtype(np_acc_dtype)

scalar_nfunc_spec = op.scalar_op.nfunc_spec

# We construct a dummy `Apply` that has the minimum required number of
# inputs for the scalar `Op`. Without this, we would get a scalar function
# with too few arguments.
dummy_node = Apply(
op,
[tensor(acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])],
[tensor(acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])],
[tensor(np_acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])],
[tensor(np_acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])],
)
elemwise_fn = numba_funcify_Elemwise(op, dummy_node, use_signature=True, **kwargs)

input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim
careduce_fn = create_multiaxis_reducer(
elemwise_fn, scalar_op_identity, axes, ndim, acc_dtype, input_name=input_name
elemwise_fn, scalar_op_identity, axes, ndim, np_acc_dtype, input_name=input_name
)

return numba.njit(careduce_fn)
Expand Down

0 comments on commit 1366221

Please sign in to comment.