Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion numba_dpex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from numba_dpex.core.kernel_interface.launcher import call_kernel
from numba_dpex.vectorizers import Vectorize as DpexVectorize

from .numba_patches import patch_arrayexpr_tree_to_ir, patch_is_ufunc
from .numba_patches import (
patch_arrayexpr_tree_to_ir,
patch_basic_indexing,
patch_is_ufunc,
)


def load_dpctl_sycl_interface():
Expand Down Expand Up @@ -81,7 +85,9 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
from .numba_patches import patch_mk_alloc

patch_mk_alloc.patch()

patch_arrayexpr_tree_to_ir.patch()
patch_basic_indexing.patch()

dpctl_sem_version = parse_sem_version(dpctl.__version__)
if dpctl_sem_version < (0, 14):
Expand Down
142 changes: 142 additions & 0 deletions numba_dpex/numba_patches/patch_basic_indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# SPDX-FileCopyrightText: 2023 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0


def patch():
"""Patches the basic_indexing function from numba.np.arrayobj

Raises:
NotImplementedError: If basic_indexing fails. Refer to the function.

Returns:
tuple: Whatever is returned by the inner basic_indexing function.
"""
from numba.core import cgutils, types
from numba.cpython import slicing
from numba.np import arrayobj
from numba.np.arrayobj import fix_integer_index
from numba.np.numpy_support import is_nonelike

from numba_dpex.core.targets.kernel_target import (
DpexKernelTargetContext,
DpexKernelTypingContext,
)
from numba_dpex.core.types import DpnpNdArray

def get_item_pointer(
context, builder, aryty, ary, inds, wraparound=False, boundscheck=False
):
# Set boundscheck=True for any pointer access that should be
# boundschecked. do_boundscheck() will handle enabling or disabling the
# actual boundschecking based on the user config.
shapes = cgutils.unpack_tuple(builder, ary.shape, count=aryty.ndim)
strides = cgutils.unpack_tuple(builder, ary.strides, count=aryty.ndim)

if (
isinstance(aryty, DpnpNdArray) # noqa: E800
and isinstance(context, DpexKernelTargetContext)
and isinstance(context.typing_context, DpexKernelTypingContext)
):
print("==========> doing this") # noqa: E800
for i in range(len(strides)):
strides[i] = builder.mul(strides[i], ary.itemsize) # noqa: E800
else: # noqa: E800
print("=========> not doing this") # noqa: E800

return cgutils.get_item_pointer2(
context,
builder,
data=ary.data,
shape=shapes,
strides=strides,
layout=aryty.layout,
inds=inds,
wraparound=wraparound,
boundscheck=boundscheck,
)

# -------------------------------------------------------------------------
# Basic indexing (with integers and slices only)

def basic_indexing(
context, builder, aryty, ary, index_types, indices, boundscheck=None
):
"""
Perform basic indexing on the given array.
A (data pointer, shapes, strides) tuple is returned describing
the corresponding view.
"""

zero = context.get_constant(types.intp, 0)
one = context.get_constant(types.intp, 1)

shapes = cgutils.unpack_tuple(builder, ary.shape, aryty.ndim)
strides = cgutils.unpack_tuple(builder, ary.strides, aryty.ndim)

output_indices = []
output_shapes = []
output_strides = []

num_newaxes = len([idx for idx in index_types if is_nonelike(idx)])
ax = 0
for indexval, idxty in zip(indices, index_types):
if idxty is types.ellipsis:
# Fill up missing dimensions at the middle
n_missing = aryty.ndim - len(indices) + 1 + num_newaxes
for i in range(n_missing):
output_indices.append(zero)
output_shapes.append(shapes[ax])
output_strides.append(strides[ax])
ax += 1
continue
# Regular index value
if isinstance(idxty, types.SliceType):
slice = context.make_helper(builder, idxty, value=indexval)
slicing.guard_invalid_slice(context, builder, idxty, slice)
slicing.fix_slice(builder, slice, shapes[ax])
output_indices.append(slice.start)
sh = slicing.get_slice_length(builder, slice)
st = slicing.fix_stride(builder, slice, strides[ax])
output_shapes.append(sh)
output_strides.append(st)
elif isinstance(idxty, types.Integer):
ind = fix_integer_index(
context, builder, idxty, indexval, shapes[ax]
)
if boundscheck:
cgutils.do_boundscheck(
context, builder, ind, shapes[ax], ax
)
output_indices.append(ind)
elif is_nonelike(idxty):
output_shapes.append(one)
output_strides.append(zero)
ax -= 1
else:
raise NotImplementedError(
"unexpected index type: %s" % (idxty,)
)
ax += 1

# Fill up missing dimensions at the end
assert ax <= aryty.ndim
while ax < aryty.ndim:
output_shapes.append(shapes[ax])
output_strides.append(strides[ax])
ax += 1

# No need to check wraparound, as negative indices were already
# fixed in the loop above.
dataptr = get_item_pointer(
context,
builder,
aryty,
ary,
output_indices,
wraparound=False,
boundscheck=False,
)
return (dataptr, output_shapes, output_strides)

arrayobj.basic_indexing = basic_indexing