
Numba Scan fails when sit-sot sequences aren't full length #923

Closed
ricardoV94 opened this issue Apr 22, 2022 · 8 comments · Fixed by #1203

@ricardoV94 (Contributor)

import aesara
import aesara.tensor as at
import numpy as np

k = at.iscalar("k")
A = at.vector("A")

result, _ = aesara.scan(
    fn=lambda prior_result, A: prior_result * A,
    outputs_info=at.ones_like(A),
    non_sequences=A,
    n_steps=k,
)

final_result = result[-1]

power = aesara.function(inputs=[A, k], outputs=final_result, mode="NUMBA")

print(power(range(10), 2))
# [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
print(power(range(10), 4))
# [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]

Both calls return A unchanged (i.e. only the first iteration's result), no matter the value of n_steps; the expected result is A**k.
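
For comparison, the default C backend returns the expected elementwise powers (`power_c` is a name introduced here for illustration; the value matches the walkthrough further below):

power_c = aesara.function(inputs=[A, k], outputs=final_result)

print(power_c(np.arange(10), 2))
# [ 0.  1.  4.  9. 16. 25. 36. 49. 64. 81.]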
ricardoV94 added the bug, Numba, and Scan labels on Apr 22, 2022
@brandonwillard (Member)

At the very least, we should add this as a unit test, because it looks like our coverage is lacking.
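
A minimal sketch of such a test (the test name and the NumPy-based assertion are mine, not from an existing suite):

import numpy as np

import aesara
import aesara.tensor as at


def test_scan_sit_sot_shortened_storage():
    k = at.iscalar("k")
    A = at.vector("A")

    result, _ = aesara.scan(
        fn=lambda prior_result, A: prior_result * A,
        outputs_info=at.ones_like(A),
        non_sequences=A,
        n_steps=k,
    )

    # Taking result[-1] triggers the save-memory rewrite that shortens
    # the sit-sot storage, which is the failing case reported above.
    power = aesara.function([A, k], result[-1], mode="NUMBA")

    np.testing.assert_allclose(power(np.arange(10.0), 4), np.arange(10.0) ** 4)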

brandonwillard added the important and help wanted labels on Apr 22, 2022
@kc611 (Contributor) commented May 13, 2022

I did some preliminary testing on this one; I suspect the issue lies somewhere in the Subtensor + Scan optimizations (or perhaps in their Numba implementations). I made the following observations:

import aesara
import aesara.tensor as at
import numpy as np

k = at.iscalar("k")
A = at.vector("A")

result, _ = aesara.scan(
    fn=lambda prior_result, A: prior_result * A,
    outputs_info=at.ones_like(A),
    non_sequences=A,
    n_steps=k,
)

final_result = result[-1]

power = aesara.function(inputs=[A, k], outputs=result, mode="NUMBA")
print(power(range(10), 2))
# [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
#  [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]
print(power(range(10), 4))
# [[  1.   1.   1.   1.   1.   1.   1.   1.   1.   1.]
#  [  0.   1.   2.   3.   4.   5.   6.   7.   8.   9.]
#  [  0.   1.   4.   9.  16.  25.  36.  49.  64.  81.]
#  [  0.   1.   8.  27.  64. 125. 216. 343. 512. 729.]]

power = aesara.function(inputs=[A, k], outputs=final_result, mode="NUMBA")
print(power(range(10), 2))
# [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
print(power(range(10), 4))
# [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]

Also, the inner function graph and the Numba source code generated for the Scan implementation were the same in both cases:

aesara.dprint(inner_fg)
#Elemwise{mul,no_inplace} [id A] 0
# |*0-<TensorType(float64, (None,))> [id B]
# |*1-<TensorType(float64, (None,))> [id C]

def scan(n_steps, auto_1645, auto_13):

    for i in range(n_steps):
        inner_args = (auto_1645[i], auto_13)
        (auto_1645[i+1], ) = numba_at_inner_func(*inner_args)

    return auto_1645

@brandonwillard (Member)

For more context, here's a print-out of the fully optimized graph used by Numba:

import aesara
import aesara.tensor as at


k = at.iscalar("k")
A = at.vector("A")

result, _ = aesara.scan(
    fn=lambda prior_result, A: prior_result * A,
    outputs_info=at.ones_like(A),
    non_sequences=A,
    n_steps=k,
)

numba_power = aesara.function(inputs=[A, k], outputs=result, mode="NUMBA")

aesara.dprint(numba_power, print_op_info=True, print_fgraph_inputs=True)
# -A [id A]
# -k [id B]
# Subtensor{int64:int64:int8} [id C] 13
#  |forall_inplace,cpu,scan_fn} [id D] 12 (outer_out_sit_sot-0)
#  | |k [id B] (n_steps)
#  | |IncSubtensor{InplaceSet;:int64:} [id E] 11 (outer_in_sit_sot-0)
#  | | |AllocEmpty{dtype='float64'} [id F] 10
#  | | | |Elemwise{Composite{(Switch(GT(i0, i1), (i1 + i0), (i1 - i0)) + i2)}}[(0, 1)] [id G] 7
#  | | | | |TensorConstant{1} [id H]
#  | | | | |Elemwise{Composite{maximum(((i0 - Switch(LT(i1, i2), i3, i2)) + i4), i5)}} [id I] 2
#  | | | | | |k [id B]
#  | | | | | |TensorConstant{1} [id J]
#  | | | | | |Elemwise{add,no_inplace} [id K] 0
#  | | | | | | |TensorConstant{1} [id L]
#  | | | | | | |k [id B]
#  | | | | | |TensorConstant{1} [id J]
#  | | | | | |TensorConstant{1} [id H]
#  | | | | | |TensorConstant{2} [id M]
#  | | | | |TensorConstant{1} [id J]
#  | | | |Shape_i{0} [id N] 1
#  | | |   |A [id A]
#  | | |Rebroadcast{(0, False)} [id O] 5
#  | | | |Alloc [id P] 3
#  | | |   |TensorConstant{(1, 1) of 1.0} [id Q]
#  | | |   |TensorConstant{1} [id J]
#  | | |   |Shape_i{0} [id N] 1
#  | | |ScalarConstant{1} [id R]
#  | |A [id A] (outer_in_non_seqs-0)
#  |ScalarFromTensor [id S] 8
#  | |Elemwise{Composite{(((Switch(LT(i0, i1), i2, i1) - i3) - i4) + i5)}} [id T] 4
#  |   |TensorConstant{1} [id J]
#  |   |Elemwise{add,no_inplace} [id K] 0
#  |   |TensorConstant{1} [id J]
#  |   |k [id B]
#  |   |TensorConstant{1} [id H]
#  |   |Elemwise{Composite{maximum(((i0 - Switch(LT(i1, i2), i3, i2)) + i4), i5)}} [id I] 2
#  |ScalarFromTensor [id U] 9
#  | |Elemwise{Composite{(((i0 - i1) - i2) + i3)}}[(0, 0)] [id V] 6
#  |   |Elemwise{add,no_inplace} [id K] 0
#  |   |k [id B]
#  |   |TensorConstant{1} [id H]
#  |   |Elemwise{Composite{maximum(((i0 - Switch(LT(i1, i2), i3, i2)) + i4), i5)}} [id I] 2
#  |ScalarConstant{1} [id W]
#
# Inner graphs:
#
# forall_inplace,cpu,scan_fn} [id D] (outer_out_sit_sot-0)
# -*0-<TensorType(float64, (None,))> [id X] -> [id E] (inner_in_sit_sot-0)
# -*1-<TensorType(float64, (None,))> [id Y] -> [id A] (inner_in_non_seqs-0)
#  >Elemwise{mul,no_inplace} [id Z] (inner_out_sit_sot-0)
#  > |*0-<TensorType(float64, (None,))> [id X] (inner_in_sit_sot-0)
#  > |*1-<TensorType(float64, (None,))> [id Y] (inner_in_non_seqs-0)

c_power = aesara.function(inputs=[A, k], outputs=result)

aesara.graph.basic.equal_computations(
    c_power.maker.fgraph.outputs,
    numba_power.maker.fgraph.outputs,
    c_power.maker.fgraph.inputs,
    numba_power.maker.fgraph.inputs,
)
# True

The equal_computations result implies that both the C and Numba graphs are the same, though.

@brandonwillard (Member)

Also, after running the Numba version a few times in the same Python session, then doing some unrelated random things, I get low-level crashes with the occasional error like malloc(): largebin double linked list corrupted (bk).

@brandonwillard (Member) commented May 14, 2022

Here's the Numba code generated by Aesara:

import inspect
from textwrap import dedent, indent

from aesara.compile.function.types import Function

from numba.np.ufunc.dufunc import DUFunc
from numba.core.dispatcher import Dispatcher


def get_numba_source(aesara_func: Function) -> None:
    """Print the Aesara-generated Numba source code for a `Function` in a "flat" format.

    XXX: This code is definitely *not* usable (i.e. Numba `jit`-able) in its
    printed form, but it should be enough for debugging.

    """

    def _get_numba_src(fn, fn_src=None, indent_str=""):

        try:
            fn_globals = inspect.getclosurevars(fn).globals
        except TypeError:
            fn_globals = {}

        for name, obj in fn_globals.items():
            if isinstance(obj, DUFunc):
                # This is a vectorized function

                inner_fn = obj.py_scalar_func

                inner_src = dedent(inspect.getsource(inner_fn))
                inner_src = inner_src.replace(obj.__name__, name)

                outer_src = f"@numba.vectorize\n{inner_src}"

                # new_indent_str = indent_str + " " * 4
                new_indent_str = indent_str
                _get_numba_src(inner_fn, fn_src=outer_src, indent_str=new_indent_str)

            elif isinstance(obj, Dispatcher):
                py_func = obj.py_func

                if py_func.__name__ != name:
                    py_func_src = dedent(inspect.getsource(py_func)).replace(
                        py_func.__name__, name
                    )
                else:
                    py_func_src = None

                _get_numba_src(py_func, fn_src=py_func_src, indent_str=indent_str)
            else:
                print(indent(f"{name} = {getattr(obj, '__name__', obj)}\n", indent_str))

        try:
            if not fn_src:
                fn_src = dedent(inspect.getsource(fn))
            print(indent(fn_src, indent_str))
        except TypeError:
            return

    _get_numba_src(aesara_func.vm.jit_fn.py_func)


get_numba_source(numba_power)

# Printed output:
@numba.vectorize
def add(auto_22236, k):
    return auto_22236+k

auto_22236 = 1

np = numpy

@numba_njit(inline="always")
def shape_i(x):
    return np.shape(x)[i]

scalar_func = less

def less(auto_23317, auto_23318):
    return scalar_func(auto_23317, auto_23318)

@numba_basic.numba_njit(inline="always")
def switch(condition, x, y):
    if condition:
        return x
    else:
        return y

scalar_func = subtract

def subtract(auto_23316, auto_23323):
    return scalar_func(auto_23316, auto_23323)

def add(auto_23324, auto_23320):
    return auto_23324+auto_23320

scalar_func = maximum

def maximum(auto_23325, auto_23321):
    return scalar_func(auto_23325, auto_23321)

@numba.vectorize
def numba_funcified_fgraph1(auto_23316, auto_23317, auto_23318, auto_23319, auto_23320, auto_23321):
    # LT(<int64>, <int64>)
    auto_23322 = less(auto_23317, auto_23318)
    # Switch(LT.0, <int64>, <int64>)
    auto_23323 = switch(auto_23322, auto_23319, auto_23318)
    # sub(<int32>, Switch.0)
    auto_23324 = subtract(auto_23316, auto_23323)
    # add(sub.0, <int8>)
    auto_23325 = add(auto_23324, auto_23320)
    # maximum(add.0, <int8>)
    auto_23326 = maximum(auto_23325, auto_23321)
    return auto_23326

auto_22217 = 1

auto_22291 = 1

auto_22286 = 2

np = numpy

numba = numba

types = numba.core.types

TypingError = TypingError

@numba.generated_jit(nopython=True)
def to_scalar(x):
    if isinstance(x, (numba.types.Number, numba.types.Boolean)):
        return lambda x: x
    elif isinstance(x, numba.types.Array):
        return lambda x: x.item()
    else:
        raise TypingError(f"{x} must be a scalar compatible type.")

def alloc(val, auto_22217, auto_22179):
    val_np = np.asarray(val)
    auto_22217_item = to_scalar(auto_22217)
    auto_22179_item = to_scalar(auto_22179)
    scalar_shape = (auto_22217_item, auto_22179_item)
    res = np.empty(scalar_shape, dtype=val_np.dtype)
    res[...] = val_np
    return res

auto_22229 = [[1.]]

scalar_func = less

def less(auto_23114, auto_23115):
    return scalar_func(auto_23114, auto_23115)

@numba_basic.numba_njit(inline="always")
def switch(condition, x, y):
    if condition:
        return x
    else:
        return y

scalar_func = subtract

def subtract(auto_23121, auto_23117):
    return scalar_func(auto_23121, auto_23117)

scalar_func = subtract

def subtract1(auto_23122, auto_23118):
    return scalar_func(auto_23122, auto_23118)

def add(auto_23123, auto_23119):
    return auto_23123+auto_23119

@numba.vectorize
def numba_funcified_fgraph2(auto_23114, auto_23115, auto_23116, auto_23117, auto_23118, auto_23119):
    # LT(<int64>, <int64>)
    auto_23120 = less(auto_23114, auto_23115)
    # Switch(LT.0, <int64>, <int64>)
    auto_23121 = switch(auto_23120, auto_23116, auto_23115)
    # sub(Switch.0, <int32>)
    auto_23122 = subtract(auto_23121, auto_23117)
    # sub(sub.0, <int8>)
    auto_23123 = subtract1(auto_23122, auto_23118)
    # add(sub.0, <int64>)
    auto_23124 = add(auto_23123, auto_23119)
    return auto_23124

numba = numba

@numba_basic.numba_njit
def rebroadcast(x):
    for axis, value in numba.literal_unroll(op_axis):
        if value and x.shape[axis] != 1:
            raise ValueError(
                ("Dimension in Rebroadcast's input was supposed to be 1")
            )
    return x

np = numpy

scalar_func = subtract

def subtract(auto_23420, auto_23421):
    return scalar_func(auto_23420, auto_23421)

scalar_func = subtract

def subtract1(auto_23424, auto_23422):
    return scalar_func(auto_23424, auto_23422)

def add(auto_23425, auto_23423):
    return auto_23425+auto_23423

@numba.vectorize
def numba_funcified_fgraph(auto_23420, auto_23421, auto_23422, auto_23423):
    # sub(<int64>, <int32>)
    auto_23424 = subtract(auto_23420, auto_23421)
    # sub(sub.0, <int8>)
    auto_23425 = subtract1(auto_23424, auto_23422)
    # add(sub.0, <int64>)
    auto_23426 = add(auto_23425, auto_23423)
    return auto_23426

def numba_funcified_fgraph_inplace(str, str_1, str_2, str_3):
    str_scalar = np.asarray(str)
    return numba_funcified_fgraph(str, str_1, str_2, str_3, str_scalar).item()

np = numpy

scalar_func = subtract

def subtract(auto_23472, auto_23471):
    return scalar_func(auto_23472, auto_23471)

def add(auto_23472, auto_23471):
    return auto_23472+auto_23471

scalar_func = greater

def greater(auto_23471, auto_23472):
    return scalar_func(auto_23471, auto_23472)

@numba_basic.numba_njit(inline="always")
def switch(condition, x, y):
    if condition:
        return x
    else:
        return y

def add1(auto_23477, auto_23473):
    return auto_23477+auto_23473

@numba.vectorize
def numba_funcified_fgraph(auto_23471, auto_23472, auto_23473):
    # sub(<int64>, <int8>)
    auto_23474 = subtract(auto_23472, auto_23471)
    # add(<int64>, <int8>)
    auto_23475 = add(auto_23472, auto_23471)
    # GT(<int8>, <int64>)
    auto_23476 = greater(auto_23471, auto_23472)
    # Switch(GT.0, add.0, sub.0)
    auto_23477 = switch(auto_23476, auto_23475, auto_23474)
    # add(Switch.0, <int64>)
    auto_23478 = add1(auto_23477, auto_23473)
    return auto_23478

def numba_funcified_fgraph_inplace1(str, str_1, str_2):
    str_1_scalar = np.asarray(str_1)
    return numba_funcified_fgraph(str, str_1, str_2, str_1_scalar).item()

@numba_basic.numba_njit(inline="always")
def scalar_from_tensor(x):
    return x.item()

@numba_basic.numba_njit(inline="always")
def scalar_from_tensor1(x):
    return x.item()

numba = numba

types = numba.core.types

TypingError = TypingError

@numba.generated_jit(nopython=True)
def to_scalar(x):
    if isinstance(x, (numba.types.Number, numba.types.Boolean)):
        return lambda x: x
    elif isinstance(x, numba.types.Array):
        return lambda x: x.item()
    else:
        raise TypingError(f"{x} must be a scalar compatible type.")

np = numpy

dtype = float64

def allocempty(auto_23483, auto_22179):
    auto_23483_item = to_scalar(auto_23483)
    auto_22179_item = to_scalar(auto_22179)
    scalar_shape = (auto_23483_item, auto_22179_item)
    return np.empty(scalar_shape, dtype)

def incsubtensor(auto_22454, auto_22165, auto_22948):
    z = auto_22454
    indices = (slice(None, auto_22948, None),)
    z[indices] = auto_22165
    return z

auto_22948 = 1

@numba.vectorize
def mul(auto_23487, auto_23488):
    return auto_23487*auto_23488

def numba_at_inner_func(auto_23487, auto_23488):
    # Elemwise{mul,no_inplace}(*0-<TensorType(float64, (None,))>, *1-<TensorType(float64, (None,))>)
    auto_23498 = mul(auto_23487, auto_23488)
    return (auto_23498,)

def scan(n_steps, auto_23382, auto_22086):

    for i in range(n_steps):
        inner_args = (auto_23382[i], auto_22086)
        (auto_23382[i+1], ) = numba_at_inner_func(*inner_args)

    return auto_23382

def subtensor(auto_23496, auto_22498, auto_22499, auto_22761):

    indices = (slice(auto_22498, auto_22499, auto_22761),)
    z = auto_23496[indices]
    return z

auto_22761 = 1

def numba_funcified_fgraph(A, k):
    # Elemwise{add,no_inplace}(TensorConstant{1}, k)
    auto_22240 = add(auto_22236, k)
    # Shape_i{0}(A)
    auto_22179 = shape_i(A)
    # Elemwise{Composite{maximum(((i0 - Switch(LT(i1, i2), i3, i2)) + i4), i5)}}(k, TensorConstant{1}, Elemwise{add,no_inplace}.0, TensorConstant{1}, TensorConstant{1}, TensorConstant{2})
    auto_23334 = numba_funcified_fgraph1(k, auto_22217, auto_22240, auto_22217, auto_22291, auto_22286)
    # Alloc(TensorConstant{(1, 1) of 1.0}, TensorConstant{1}, Shape_i{0}.0)
    auto_22247 = alloc(auto_22229, auto_22217, auto_22179)
    # Elemwise{Composite{(((Switch(LT(i0, i1), i2, i1) - i3) - i4) + i5)}}(TensorConstant{1}, Elemwise{add,no_inplace}.0, TensorConstant{1}, k, TensorConstant{1}, Elemwise{Composite{maximum(((i0 - Switch(LT(i1, i2), i3, i2)) + i4), i5)}}.0)
    auto_23132 = numba_funcified_fgraph2(auto_22217, auto_22240, auto_22217, k, auto_22291, auto_23334)
    # Rebroadcast{(0, False)}(Alloc.0)
    auto_22165 = rebroadcast(auto_22247)
    # Elemwise{Composite{(((i0 - i1) - i2) + i3)}}[(0, 0)](Elemwise{add,no_inplace}.0, k, TensorConstant{1}, Elemwise{Composite{maximum(((i0 - Switch(LT(i1, i2), i3, i2)) + i4), i5)}}.0)
    auto_23432 = numba_funcified_fgraph_inplace(auto_22240, k, auto_22291, auto_23334)
    # Elemwise{Composite{(Switch(GT(i0, i1), (i1 + i0), (i1 - i0)) + i2)}}[(0, 1)](TensorConstant{1}, Elemwise{Composite{maximum(((i0 - Switch(LT(i1, i2), i3, i2)) + i4), i5)}}.0, TensorConstant{1})
    auto_23483 = numba_funcified_fgraph_inplace1(auto_22291, auto_23334, auto_22217)
    # ScalarFromTensor(Elemwise{Composite{(((Switch(LT(i0, i1), i2, i1) - i3) - i4) + i5)}}.0)
    auto_22498 = scalar_from_tensor(auto_23132)
    # ScalarFromTensor(Elemwise{Composite{(((i0 - i1) - i2) + i3)}}[(0, 0)].0)
    auto_22499 = scalar_from_tensor1(auto_23432)
    # AllocEmpty{dtype='float64'}(Elemwise{Composite{(Switch(GT(i0, i1), (i1 + i0), (i1 - i0)) + i2)}}[(0, 1)].0, Shape_i{0}.0)
    auto_22454 = allocempty(auto_23483, auto_22179)
    # IncSubtensor{InplaceSet;:int64:}(AllocEmpty{dtype='float64'}.0, Rebroadcast{(0, False)}.0, ScalarConstant{1})
    auto_23382 = incsubtensor(auto_22454, auto_22165, auto_22948)
    # forall_inplace,cpu,scan_fn}(k, IncSubtensor{InplaceSet;:int64:}.0, A)
    auto_23496 = scan(k, auto_23382, A)
    # Subtensor{int64:int64:int8}(forall_inplace,cpu,scan_fn}.0, ScalarFromTensor.0, ScalarFromTensor.0, ScalarConstant{1})
    auto_22762 = subtensor(auto_23496, auto_22498, auto_22499, auto_22761)
    return (auto_22762,)

@brandonwillard (Member) commented May 16, 2022

Looks like the rewrite save_mem_new_scan is effectively shortening the storage array given to the Scan Op (to save memory), but the Numba implementation is written assuming that the entire array is present at full length, so that sit-sot results (i.e. lag-one input/output relationships) can be read at index i and stored at index i + 1, instead of read at i - 1 and stored at i. This explains the low-level crashes I've been seeing: we're writing results outside of the allocated array ranges.
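
In pure-Python terms, handling a shortened sit-sot buffer would require something like circular indexing instead of the current absolute indexing (a hypothetical sketch, not the actual fix):

def scan_circular_sketch(n_steps, buffer, A):
    # Hypothetical: wrap the step index around the buffer length instead
    # of assuming the buffer holds all n_steps + 1 entries. The initial
    # value sits at buffer[0]; the final result ends up at
    # buffer[n_steps % size].
    size = buffer.shape[0]
    for i in range(n_steps):
        buffer[(i + 1) % size] = buffer[i % size] * A
    return buffer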

Here's a quick run through of the inputs given to the Scan node in both the C and Numba versions of the compiled function:

import numpy as np

import aesara
import aesara.tensor as at


k = at.iscalar("k")
A = at.vector("A")

result, _ = aesara.scan(
    fn=lambda prior_result, A: prior_result * A,
    outputs_info=at.ones_like(A),
    non_sequences=A,
    n_steps=k,
)

power = aesara.function(inputs=[A, k], outputs=result)

power(np.arange(10), 2)
# array([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
#        [ 0.,  1.,  4.,  9., 16., 25., 36., 49., 64., 81.]])

scan_inputs = power.maker.fgraph.outputs[0].owner.inputs[0].owner.inputs
numba_scan_inputs_fn = aesara.function(
    inputs=[A, k],
    outputs=scan_inputs,
    mode="NUMBA",
    on_unused_input="ignore",
    accept_inplace=True,
)

scan_inputs_vals = numba_scan_inputs_fn(np.arange(10), 2)
scan_inputs_vals
# [array(2, dtype=int32),
#  array([[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
#           1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
#           1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
#           1.00000000e+00],
#         [-1.36311572e+57, -1.36311572e+57, -1.36311572e+57,
#          -1.36311572e+57, -1.36311572e+57, -1.36311572e+57,
#          -1.36311572e+57, -1.36311572e+57, -1.36311572e+57,
#          -1.36311572e+57]]),
#  array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])]

scan_inputs_vals[1].shape
# (2, 10)

As we can see, the array used as storage by the Scan node has a length of two along the iterated dimension, so this is what the Numba implementation would do if it were written in pure Python:

def scan_test(n_steps, auto_23382, auto_22086):

    for i in range(n_steps):
        inner_args = (auto_23382[i], auto_22086)
        auto_23382[i + 1] = np.multiply(*inner_args)

    return auto_23382


scan_test(*scan_inputs_vals)
# IndexError: index 2 is out of bounds for axis 0 with size 2

Apparently, Numba doesn't perform the bounds check and silently corrupts the session instead.
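
For reference, Numba's bounds checking is disabled by default in nopython mode; it can be enabled per function with the boundscheck flag (or globally with the NUMBA_BOUNDSCHECK=1 environment variable), at some runtime cost. A minimal sketch:

import numba
import numpy as np

@numba.njit(boundscheck=True)
def scan_test_checked(n_steps, buf, A):
    # Same loop as scan_test above, but compiled with bounds checking,
    # so the out-of-bounds write raises IndexError instead of silently
    # corrupting memory.
    for i in range(n_steps):
        buf[i + 1] = buf[i] * A
    return buf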

Compiling without that Scan optimization avoids the issue:

from aesara.compile.mode import get_mode


mode = get_mode("NUMBA").excluding("scan_save_mem")

numba_power_fn = aesara.function(inputs=[A, k], outputs=result, mode=mode)

numba_power_fn(np.arange(10), 2)
# array([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
#        [ 0.,  1.,  4.,  9., 16., 25., 36., 49., 64., 81.]])

brandonwillard changed the title from "Numba backend gives wrong results with Scan documentation example" to "Numba Scan fails when sit-sot sequences aren't n_steps in length" on May 16, 2022
brandonwillard changed the title from "Numba Scan fails when sit-sot sequences aren't n_steps in length" to "Numba Scan fails when sit-sot sequences aren't full length" on May 16, 2022
@kc611 (Contributor) commented Jun 4, 2022

Apparently, Numba doesn't perform the bounds check and silently corrupts the session instead.

I ran into this same issue some time ago. See: numba/numba#8127

However, performing that bounds check will probably have an impact on performance, so it's a trade-off. Perhaps the solution here could instead be a mechanism for the Numba back-end to exclude certain low-level optimizations (such as those dealing with memory), so that we can hand off that responsibility to the Numba/LLVM framework.

@brandonwillard (Member) commented Jun 4, 2022

It looks like this is ultimately yet another special Scan condition that we need to add to our Numba implementation. After refactoring the Cython implementation, I think I know exactly where/when/how this special case is handled.

In other words, Numba isn't doing anything wrong, and the Numba implementation isn't either. This is just an old Theano optimization that attempts to reduce the amount of intermediate storage used by the Op; the end result is that our current assumption of full-length output storage simply doesn't hold.

I ran into this same issue some time ago. See: numba/numba#8127

You're right, though; there do seem to be some discrepancies in Numba's bounds checking behavior.
