From bd3144926eb006d4a730dd6802495f7003e1dbb5 Mon Sep 17 00:00:00 2001 From: khaled Date: Thu, 4 Jan 2024 18:35:50 -0600 Subject: [PATCH 1/8] Adding patch_basic_indexing --- numba_dpex/__init__.py | 8 +- .../numba_patches/patch_basic_indexing.py | 92 +++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 numba_dpex/numba_patches/patch_basic_indexing.py diff --git a/numba_dpex/__init__.py b/numba_dpex/__init__.py index cf72ed70d1..a5daee9d73 100644 --- a/numba_dpex/__init__.py +++ b/numba_dpex/__init__.py @@ -20,7 +20,11 @@ from numba_dpex.core.kernel_interface.launcher import call_kernel from numba_dpex.vectorizers import Vectorize as DpexVectorize -from .numba_patches import patch_arrayexpr_tree_to_ir, patch_is_ufunc +from .numba_patches import ( + patch_arrayexpr_tree_to_ir, + patch_basic_indexing, + patch_is_ufunc, +) def load_dpctl_sycl_interface(): @@ -81,7 +85,9 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]: from .numba_patches import patch_mk_alloc patch_mk_alloc.patch() + patch_arrayexpr_tree_to_ir.patch() +patch_basic_indexing.patch() dpctl_sem_version = parse_sem_version(dpctl.__version__) if dpctl_sem_version < (0, 14): diff --git a/numba_dpex/numba_patches/patch_basic_indexing.py b/numba_dpex/numba_patches/patch_basic_indexing.py new file mode 100644 index 0000000000..44b681b15f --- /dev/null +++ b/numba_dpex/numba_patches/patch_basic_indexing.py @@ -0,0 +1,92 @@ +def patch(): + from numba.core import cgutils, types + from numba.cpython import slicing + from numba.np import arrayobj + from numba.np.arrayobj import fix_integer_index + from numba.np.numpy_support import is_nonelike + + # ------------------------------------------------------------------------- + # Basic indexing (with integers and slices only) + + def basic_indexing( + context, builder, aryty, ary, index_types, indices, boundscheck=None + ): + """ + Perform basic indexing on the given array. 
+ A (data pointer, shapes, strides) tuple is returned describing + the corresponding view. + """ + print("============> doing this basic_indexing") + + zero = context.get_constant(types.intp, 0) + one = context.get_constant(types.intp, 1) + + shapes = cgutils.unpack_tuple(builder, ary.shape, aryty.ndim) + strides = cgutils.unpack_tuple(builder, ary.strides, aryty.ndim) + + output_indices = [] + output_shapes = [] + output_strides = [] + + num_newaxes = len([idx for idx in index_types if is_nonelike(idx)]) + ax = 0 + for indexval, idxty in zip(indices, index_types): + if idxty is types.ellipsis: + # Fill up missing dimensions at the middle + n_missing = aryty.ndim - len(indices) + 1 + num_newaxes + for i in range(n_missing): + output_indices.append(zero) + output_shapes.append(shapes[ax]) + output_strides.append(strides[ax]) + ax += 1 + continue + # Regular index value + if isinstance(idxty, types.SliceType): + slice = context.make_helper(builder, idxty, value=indexval) + slicing.guard_invalid_slice(context, builder, idxty, slice) + slicing.fix_slice(builder, slice, shapes[ax]) + output_indices.append(slice.start) + sh = slicing.get_slice_length(builder, slice) + st = slicing.fix_stride(builder, slice, strides[ax]) + output_shapes.append(sh) + output_strides.append(st) + elif isinstance(idxty, types.Integer): + ind = fix_integer_index( + context, builder, idxty, indexval, shapes[ax] + ) + if boundscheck: + cgutils.do_boundscheck( + context, builder, ind, shapes[ax], ax + ) + output_indices.append(ind) + elif is_nonelike(idxty): + output_shapes.append(one) + output_strides.append(zero) + ax -= 1 + else: + raise NotImplementedError( + "unexpected index type: %s" % (idxty,) + ) + ax += 1 + + # Fill up missing dimensions at the end + assert ax <= aryty.ndim + while ax < aryty.ndim: + output_shapes.append(shapes[ax]) + output_strides.append(strides[ax]) + ax += 1 + + # No need to check wraparound, as negative indices were already + # fixed in the loop above. 
+ dataptr = cgutils.get_item_pointer( + context, + builder, + aryty, + ary, + output_indices, + wraparound=False, + boundscheck=False, + ) + return (dataptr, output_shapes, output_strides) + + arrayobj.basic_indexing = basic_indexing From 77bce4e433900d06d8b9a5cb9187d9ae6dc2f439 Mon Sep 17 00:00:00 2001 From: khaled Date: Thu, 4 Jan 2024 19:13:12 -0600 Subject: [PATCH 2/8] IR for multiplying with itemsize --- numba_dpex/numba_patches/patch_basic_indexing.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/numba_dpex/numba_patches/patch_basic_indexing.py b/numba_dpex/numba_patches/patch_basic_indexing.py index 44b681b15f..e0becb7261 100644 --- a/numba_dpex/numba_patches/patch_basic_indexing.py +++ b/numba_dpex/numba_patches/patch_basic_indexing.py @@ -5,6 +5,8 @@ def patch(): from numba.np.arrayobj import fix_integer_index from numba.np.numpy_support import is_nonelike + from numba_dpex.core.types import DpnpNdArray + # ------------------------------------------------------------------------- # Basic indexing (with integers and slices only) @@ -24,6 +26,9 @@ def basic_indexing( shapes = cgutils.unpack_tuple(builder, ary.shape, aryty.ndim) strides = cgutils.unpack_tuple(builder, ary.strides, aryty.ndim) + if isinstance(aryty, DpnpNdArray): + print(f"multiply each stride value with itemsize = {ary.itemsize}") + output_indices = [] output_shapes = [] output_strides = [] From 442c6fd533a0da61d4fcf1d2be98b780b895a02f Mon Sep 17 00:00:00 2001 From: khaled Date: Fri, 5 Jan 2024 12:07:40 -0600 Subject: [PATCH 3/8] Add license --- numba_dpex/numba_patches/patch_basic_indexing.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/numba_dpex/numba_patches/patch_basic_indexing.py b/numba_dpex/numba_patches/patch_basic_indexing.py index e0becb7261..32b656a03d 100644 --- a/numba_dpex/numba_patches/patch_basic_indexing.py +++ b/numba_dpex/numba_patches/patch_basic_indexing.py @@ -1,3 +1,8 @@ +# SPDX-FileCopyrightText: 2023 Intel Corporation +# 
SPDX-License-Identifier: Apache-2.0 + + def patch(): from numba.core import cgutils, types from numba.cpython import slicing @@ -28,6 +33,10 @@ def basic_indexing( if isinstance(aryty, DpnpNdArray): print(f"multiply each stride value with itemsize = {ary.itemsize}") + for i in range(len(strides)): + u = strides[i] + v = builder.mul(u, ary.itemsize) + strides[i] = v output_indices = [] output_shapes = [] From b9959ad5a08c8b72a74482e7a2cda05e307f5798 Mon Sep 17 00:00:00 2001 From: khaled Date: Fri, 5 Jan 2024 12:54:09 -0600 Subject: [PATCH 4/8] Monkeypatch get_item_pointer through basic_indexing --- .../numba_patches/patch_basic_indexing.py | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/numba_dpex/numba_patches/patch_basic_indexing.py b/numba_dpex/numba_patches/patch_basic_indexing.py index 32b656a03d..74016a34f2 100644 --- a/numba_dpex/numba_patches/patch_basic_indexing.py +++ b/numba_dpex/numba_patches/patch_basic_indexing.py @@ -12,6 +12,30 @@ def patch(): from numba_dpex.core.types import DpnpNdArray + def get_item_pointer( + context, builder, aryty, ary, inds, wraparound=False, boundscheck=False + ): + # Set boundscheck=True for any pointer access that should be + # boundschecked. do_boundscheck() will handle enabling or disabling the + # actual boundschecking based on the user config. 
+ shapes = cgutils.unpack_tuple(builder, ary.shape, count=aryty.ndim) + strides = cgutils.unpack_tuple(builder, ary.strides, count=aryty.ndim) + + if isinstance(aryty, DpnpNdArray): + print("==========> we are doing this get_item_pointer") + + return cgutils.get_item_pointer2( + context, + builder, + data=ary.data, + shape=shapes, + strides=strides, + layout=aryty.layout, + inds=inds, + wraparound=wraparound, + boundscheck=boundscheck, + ) + # ------------------------------------------------------------------------- # Basic indexing (with integers and slices only) @@ -23,7 +47,6 @@ def basic_indexing( A (data pointer, shapes, strides) tuple is returned describing the corresponding view. """ - print("============> doing this basic_indexing") zero = context.get_constant(types.intp, 0) one = context.get_constant(types.intp, 1) @@ -31,13 +54,6 @@ def basic_indexing( shapes = cgutils.unpack_tuple(builder, ary.shape, aryty.ndim) strides = cgutils.unpack_tuple(builder, ary.strides, aryty.ndim) - if isinstance(aryty, DpnpNdArray): - print(f"multiply each stride value with itemsize = {ary.itemsize}") - for i in range(len(strides)): - u = strides[i] - v = builder.mul(u, ary.itemsize) - strides[i] = v - output_indices = [] output_shapes = [] output_strides = [] @@ -92,7 +108,7 @@ def basic_indexing( # No need to check wraparound, as negative indices were already # fixed in the loop above. 
- dataptr = cgutils.get_item_pointer( + dataptr = get_item_pointer( context, builder, aryty, From e2b06f3ddbf562d8fd5456f5867ebcb1dbcbda8b Mon Sep 17 00:00:00 2001 From: khaled Date: Fri, 5 Jan 2024 12:58:52 -0600 Subject: [PATCH 5/8] Added docs to the patch --- numba_dpex/numba_patches/patch_basic_indexing.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/numba_dpex/numba_patches/patch_basic_indexing.py b/numba_dpex/numba_patches/patch_basic_indexing.py index 74016a34f2..a670769bbb 100644 --- a/numba_dpex/numba_patches/patch_basic_indexing.py +++ b/numba_dpex/numba_patches/patch_basic_indexing.py @@ -4,6 +4,14 @@ def patch(): + """Patches the basic_indexing function from numba.np.arrayobj + + Raises: + NotImplementedError: If basic_indexing fails. Refer to the function. + + Returns: + tuple: Whatever is returned by the inner basic_indexing function. + """ from numba.core import cgutils, types from numba.cpython import slicing from numba.np import arrayobj @@ -22,7 +30,8 @@ def get_item_pointer( strides = cgutils.unpack_tuple(builder, ary.strides, count=aryty.ndim) if isinstance(aryty, DpnpNdArray): - print("==========> we are doing this get_item_pointer") + for i in range(len(strides)): + strides[i] = builder.mul(strides[i], ary.itemsize) return cgutils.get_item_pointer2( context, From 60ef735a7bd0699cb446dd581f03f78326004790 Mon Sep 17 00:00:00 2001 From: khaled Date: Fri, 5 Jan 2024 13:07:58 -0600 Subject: [PATCH 6/8] Check on context type --- numba_dpex/numba_patches/patch_basic_indexing.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/numba_dpex/numba_patches/patch_basic_indexing.py b/numba_dpex/numba_patches/patch_basic_indexing.py index a670769bbb..bbe8e8bb1d 100644 --- a/numba_dpex/numba_patches/patch_basic_indexing.py +++ b/numba_dpex/numba_patches/patch_basic_indexing.py @@ -18,6 +18,7 @@ def patch(): from numba.np.arrayobj import fix_integer_index from numba.np.numpy_support import is_nonelike + from 
numba_dpex.core.targets.kernel_target import DpexKernelTargetContext from numba_dpex.core.types import DpnpNdArray def get_item_pointer( @@ -29,7 +30,9 @@ def get_item_pointer( shapes = cgutils.unpack_tuple(builder, ary.shape, count=aryty.ndim) strides = cgutils.unpack_tuple(builder, ary.strides, count=aryty.ndim) - if isinstance(aryty, DpnpNdArray): + if isinstance(aryty, DpnpNdArray) and isinstance( + context, DpexKernelTargetContext + ): for i in range(len(strides)): strides[i] = builder.mul(strides[i], ary.itemsize) From d2036d3f98e55e391adda105d2cf5d2d68c90ad8 Mon Sep 17 00:00:00 2001 From: khaled Date: Mon, 8 Jan 2024 16:52:57 -0600 Subject: [PATCH 7/8] wip --- .../numba_patches/patch_basic_indexing.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/numba_dpex/numba_patches/patch_basic_indexing.py b/numba_dpex/numba_patches/patch_basic_indexing.py index bbe8e8bb1d..00323c334b 100644 --- a/numba_dpex/numba_patches/patch_basic_indexing.py +++ b/numba_dpex/numba_patches/patch_basic_indexing.py @@ -18,7 +18,10 @@ def patch(): from numba.np.arrayobj import fix_integer_index from numba.np.numpy_support import is_nonelike - from numba_dpex.core.targets.kernel_target import DpexKernelTargetContext + from numba_dpex.core.targets.kernel_target import ( + DpexKernelTargetContext, + DpexKernelTypingContext, + ) from numba_dpex.core.types import DpnpNdArray def get_item_pointer( @@ -30,11 +33,16 @@ def get_item_pointer( shapes = cgutils.unpack_tuple(builder, ary.shape, count=aryty.ndim) strides = cgutils.unpack_tuple(builder, ary.strides, count=aryty.ndim) - if isinstance(aryty, DpnpNdArray) and isinstance( - context, DpexKernelTargetContext - ): - for i in range(len(strides)): - strides[i] = builder.mul(strides[i], ary.itemsize) + # if ( + # isinstance(aryty, DpnpNdArray) # noqa: E800 + # and isinstance(context, DpexKernelTargetContext) + # and isinstance(context.typing_context, DpexKernelTypingContext) + # ): + # print("==========> 
doing this") # noqa: E800 + # for i in range(len(strides)): + # strides[i] = builder.mul(strides[i], ary.itemsize) # noqa: E800 + # else: # noqa: E800 + # print("=========> not doing this") # noqa: E800 return cgutils.get_item_pointer2( context, From b31d4a258f068596bf511ec03126148ec36b1cf2 Mon Sep 17 00:00:00 2001 From: khaled Date: Mon, 8 Jan 2024 17:14:07 -0600 Subject: [PATCH 8/8] wip --- .../numba_patches/patch_basic_indexing.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/numba_dpex/numba_patches/patch_basic_indexing.py b/numba_dpex/numba_patches/patch_basic_indexing.py index 00323c334b..f4c51952d7 100644 --- a/numba_dpex/numba_patches/patch_basic_indexing.py +++ b/numba_dpex/numba_patches/patch_basic_indexing.py @@ -33,16 +33,16 @@ def get_item_pointer( shapes = cgutils.unpack_tuple(builder, ary.shape, count=aryty.ndim) strides = cgutils.unpack_tuple(builder, ary.strides, count=aryty.ndim) - # if ( - # isinstance(aryty, DpnpNdArray) # noqa: E800 - # and isinstance(context, DpexKernelTargetContext) - # and isinstance(context.typing_context, DpexKernelTypingContext) - # ): - # print("==========> doing this") # noqa: E800 - # for i in range(len(strides)): - # strides[i] = builder.mul(strides[i], ary.itemsize) # noqa: E800 - # else: # noqa: E800 - # print("=========> not doing this") # noqa: E800 + if ( + isinstance(aryty, DpnpNdArray) # noqa: E800 + and isinstance(context, DpexKernelTargetContext) + and isinstance(context.typing_context, DpexKernelTypingContext) + ): + print("==========> doing this") # noqa: E800 + for i in range(len(strides)): + strides[i] = builder.mul(strides[i], ary.itemsize) # noqa: E800 + else: # noqa: E800 + print("=========> not doing this") # noqa: E800 return cgutils.get_item_pointer2( context,