Skip to content

Commit

Permalink
Fixed Dimshuffle for scalar result cases
Browse files Browse the repository at this point in the history
  • Loading branch information
kc611 authored and brandonwillard committed Oct 22, 2021
1 parent 0f8c81c commit 263a7c7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 15 deletions.
40 changes: 25 additions & 15 deletions aesara/link/numba/dispatch/elemwise.py
Expand Up @@ -324,7 +324,6 @@ def numba_funcify_DimShuffle(op, **kwargs):
inplace = op.inplace

ndim_new_shape = len(shuffle) + len(augment)
create_zeros_tuple = numba_basic.create_tuple_creator(lambda _: 0, ndim_new_shape)

if len(shuffle) > 0:

Expand All @@ -346,24 +345,35 @@ def populate_new_shape(i, j, new_shape, shuffle_shape):
def populate_new_shape(i, j, new_shape, shuffle_shape):
return j, tuple_setitem(new_shape, i, 1)

@numba.njit
def dimshuffle_inner(x, shuffle):
res = np.transpose(x, shuffle + drop)
shuffle_shape = res.shape[: len(shuffle)]
if ndim_new_shape > 0:
create_zeros_tuple = numba_basic.create_tuple_creator(
lambda _: 0, ndim_new_shape
)

new_shape = create_zeros_tuple()
@numba.njit
def dimshuffle_inner(x, shuffle):
res = np.transpose(x, shuffle + drop)
shuffle_shape = res.shape[: len(shuffle)]

j = 0
for i in range(len(new_shape)):
j, new_shape = populate_new_shape(i, j, new_shape, shuffle_shape)
new_shape = create_zeros_tuple()

# FIXME: Numba's `array.reshape` only accepts C arrays.
res_reshape = np.reshape(np.ascontiguousarray(res), new_shape)
j = 0
for i in range(len(new_shape)):
j, new_shape = populate_new_shape(i, j, new_shape, shuffle_shape)

if not inplace:
return res_reshape.copy()
else:
return res_reshape
# FIXME: Numba's `array.reshape` only accepts C arrays.
res_reshape = np.reshape(np.ascontiguousarray(res), new_shape)

if not inplace:
return res_reshape.copy()
else:
return res_reshape

else:

@numba.njit
def dimshuffle_inner(x, shuffle):
return x.item()

# Without the following wrapper function we would see this error:
# E No implementation of function Function(<built-in function getitem>) found for signature:
Expand Down
8 changes: 8 additions & 0 deletions tests/link/test_numba.py
Expand Up @@ -691,6 +691,14 @@ def test_AllocDiag(v, offset):
(0,),
True,
),
(
set_test_value(
aet.tensor(config.floatX, [True, True, True], name="a"),
np.array([[[1.0]]], dtype=config.floatX),
),
(),
True,
),
],
)
def test_Dimshuffle(v, new_order, inplace):
Expand Down

0 comments on commit 263a7c7

Please sign in to comment.