Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed Dimshuffle for scalar result cases #625

Merged
merged 2 commits into from Oct 22, 2021

Conversation

kc611
Copy link
Member

@kc611 kc611 commented Oct 21, 2021

Fixes #621

A simple fix which returns a scalar array for the particular case of DimShuffle

import numpy as np
import aesara
from aesara.graph.optdb import OptimizationQuery
from aesara.compile.mode import Mode
from aesara.link.numba.linker import NumbaLinker
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.type import TensorType

opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)

orig_shape = (1, 1, 1)
dims = ()
out_shape = ()

broadcastables = [(entry == 1) for entry in orig_shape]
x = TensorType(aesara.config.floatX, broadcastables)("x")
out = DimShuffle(broadcastables, dims)(x)

py_res = aesara.function(inputs=[x], outputs=[out], mode=py_mode)(np.ones(orig_shape, dtype=aesara.config.floatX)) # returns array([1.])
numba_res = aesara.function(inputs=[x], outputs=[out], mode=numba_mode)(np.ones(orig_shape, dtype=aesara.config.floatX)) # returns array([1.])
pass

@brandonwillard brandonwillard added bug Something isn't working important labels Oct 21, 2021
@codecov
Copy link

codecov bot commented Oct 21, 2021

Codecov Report

Merging #625 (cda0357) into main (0f8c81c) will increase coverage by 0.00%.
The diff coverage is 86.66%.

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #625   +/-   ##
=======================================
  Coverage   77.02%   77.02%           
=======================================
  Files         157      157           
  Lines       46915    46918    +3     
  Branches    10265    10266    +1     
=======================================
+ Hits        36136    36139    +3     
  Misses       8196     8196           
  Partials     2583     2583           
Impacted Files Coverage Δ
aesara/link/numba/dispatch/elemwise.py 97.54% <86.66%> (+0.03%) ⬆️

@brandonwillard brandonwillard merged commit 263a7c7 into aesara-devs:main Oct 22, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working important
Projects
None yet
Development

Successfully merging this pull request may close these issues.

DimShuffle's Numba implementation cannot handle cases when output is empty/scalar.
2 participants