DimShuffle's Numba implementation cannot handle cases when output is empty/scalar. #621

Closed
kc611 opened this issue Oct 19, 2021 · 0 comments · Fixed by #625

Labels
bug Something isn't working important Numba Involves Numba transpilation

Comments

@kc611
Member

kc611 commented Oct 19, 2021

Description of your problem or feature request

DimShuffle's Numba implementation fails when the output shape is () (i.e. when all dimensions are dropped and the output is a 0-d, scalar-like array). This happens because of the n > 0 assertion in create_tuple_creator:

def create_tuple_creator(f, n):
    """Construct a compile-time ``tuple``-comprehension-like loop.

    See https://github.com/numba/numba/issues/2771#issuecomment-414358902
    """
    assert n > 0
    # ... (rest of the function omitted)

create_tuple_creator is used in the Numba implementation of DimShuffle to create the output's shape tuple:

create_zeros_tuple = numba_basic.create_tuple_creator(lambda _: 0, ndim_new_shape)
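When every dimension is dropped, ndim_new_shape is 0, so the n > 0 assertion rejects exactly the case this issue is about. Below is a minimal pure-Python sketch of the closure-chaining pattern behind create_tuple_creator, assuming the only goal is to let n == 0 produce an empty tuple. The name `create_tuple_creator_sketch` is hypothetical, and the sketch omits any Numba compilation, so it illustrates the logic only; it says nothing about whether Numba will type an empty-tuple base case.

```python
def create_tuple_creator_sketch(f, n):
    """Hypothetical pure-Python analogue of ``create_tuple_creator``.

    Returns a function computing ``(f(0, *args), ..., f(n - 1, *args))`` by
    chaining one closure per element instead of looping at call time.
    Unlike the original, ``n == 0`` is allowed and yields ``()``.
    """
    assert n >= 0

    def creator(args):
        # Base case: the empty tuple, which is exactly the n == 0 result.
        return ()

    for i in range(n):
        # Bind the previous creator and the current i as defaults so each
        # closure keeps its own values instead of the final loop values.
        def creator(args, creator=creator, i=i):
            return creator(args) + (f(i, *args),)

    return lambda *args: creator(args)


# Mirrors the DimShuffle usage: a tuple of zeros of length ndim_new_shape.
create_zeros_3 = create_tuple_creator_sketch(lambda _: 0, 3)
create_zeros_0 = create_tuple_creator_sketch(lambda _: 0, 0)
print(create_zeros_3())  # (0, 0, 0)
print(create_zeros_0())  # ()
```

Starting from an empty tuple makes n == 0 fall out of the same loop rather than needing a special branch.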

Please provide a minimal, self-contained, and reproducible example.

import numpy as np
import aesara
from aesara.graph.optdb import OptimizationQuery
from aesara.compile.mode import Mode
from aesara.link.numba.linker import NumbaLinker
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.type import TensorType

opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)

orig_shape = (1, 1, 1)
dims = ()
out_shape = ()

broadcastables = [(entry == 1) for entry in orig_shape]
x = TensorType(aesara.config.floatX, broadcastables)("x")
out = DimShuffle(broadcastables, dims)(x)

x_val = np.ones(orig_shape, dtype=aesara.config.floatX)

py_res = aesara.function(inputs=[x], outputs=[out], mode=py_mode)(x_val)  # Passes
numba_res = aesara.function(inputs=[x], outputs=[out], mode=numba_mode)(x_val)  # Fails
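For reference, the result expected from the py mode can be approximated in plain NumPy. This is a sketch that assumes floatX is float64: dropping every length-1 dimension of a (1, 1, 1) input leaves a 0-d array whose shape is the empty tuple (), which is the shape tuple the Numba implementation currently cannot build.

```python
import numpy as np

# Same input as the reproduction: shape (1, 1, 1), every dimension broadcastable.
x_val = np.ones((1, 1, 1), dtype="float64")

# Dropping all length-1 dimensions (what DimShuffle(broadcastables, ()) asks for)
# leaves a 0-d array.
expected = np.squeeze(x_val)

print(expected.shape, expected.ndim)  # () 0
```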
brandonwillard added the bug (Something isn't working), important, and Numba (Involves Numba transpilation) labels Oct 19, 2021