Skip to content

Commit

Permalink
Merge pull request #1680 from IntelPython/fix-basic-slicing-on-empty-…
Browse files Browse the repository at this point in the history
…arrays

Fix bug in basic slicing of empty arrays
  • Loading branch information
oleksandr-pavlyk committed May 16, 2024
2 parents ba09dd8 + 9fa1aec commit f3d8ee7
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
16 changes: 11 additions & 5 deletions dpctl/tensor/_slicing.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,11 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
sh0 = _slice_len(sl_start, sl_stop, sl_step)
str0 = sl_step * strides[0]
new_strides = strides if (sl_step == 1 or sh0 == 0) else (str0,) + strides[1:]
new_offset = offset if sh0 == 0 else offset + sl_start * strides[0]
new_shape = (sh0, ) + shape[1:]
is_empty = any(sh_i == 0 for sh_i in new_shape)
new_offset = offset if is_empty else offset + sl_start * strides[0]
return (
(sh0, ) + shape[1:],
new_shape,
new_strides,
new_offset,
_no_advanced_ind,
Expand All @@ -135,11 +137,15 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
return ((0,) + shape, (0,) + strides, offset, _no_advanced_ind, _no_advanced_pos)
elif _is_integral(ind):
ind = ind.__index__()
new_shape = shape[1:]
new_strides = strides[1:]
is_empty = any(sh_i == 0 for sh_i in new_shape)
if 0 <= ind < shape[0]:
return (shape[1:], strides[1:], offset + ind * strides[0], _no_advanced_ind, _no_advanced_pos)
new_offset = offset if is_empty else offset + ind * strides[0]
return (new_shape, new_strides, new_offset, _no_advanced_ind, _no_advanced_pos)
elif -shape[0] <= ind < 0:
return (shape[1:], strides[1:],
offset + (shape[0] + ind) * strides[0], _no_advanced_ind, _no_advanced_pos)
new_offset = offset if is_empty else offset + (shape[0] + ind) * strides[0]
return (new_shape, new_strides, new_offset, _no_advanced_ind, _no_advanced_pos)
else:
raise IndexError(
"Index {0} is out of range for axes 0 with "
Expand Down
24 changes: 24 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,30 @@ def test_slicing_basic():
assert np.array_equal(Xh, Xnp[Xnp[2] : Xnp[5]])


def test_slicing_empty():
try:
X = dpt.usm_ndarray((0, 10), dtype="i4")
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
x = dpt.moveaxis(X, 1, 0)
# this used to raise ValueError
y = x[1]
assert y.ndim == 1
assert y.shape == (0,)
assert y.dtype == X.dtype
assert y.usm_type == X.usm_type
assert y.sycl_queue == X.sycl_queue
w = x[1:3]
assert w.ndim == 2
assert w.shape == (
2,
0,
)
assert w.dtype == X.dtype
assert w.usm_type == X.usm_type
assert w.sycl_queue == X.sycl_queue


def test_ctor_invalid_shape():
with pytest.raises(TypeError):
dpt.usm_ndarray(dict())
Expand Down

0 comments on commit f3d8ee7

Please sign in to comment.