Skip to content

Commit

Permalink
Merge pull request #1677 from IntelPython/reshape-improvements
Browse files Browse the repository at this point in the history
Reshape improvements
  • Loading branch information
oleksandr-pavlyk committed May 16, 2024
2 parents d840cee + 9d2633f commit c994666
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 9 deletions.
22 changes: 13 additions & 9 deletions dpctl/tensor/_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import numpy as np

import dpctl.tensor as dpt
from dpctl.tensor._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray
from dpctl.tensor._tensor_impl import (
_copy_usm_ndarray_for_reshape,
_ravel_multi_index,
Expand Down Expand Up @@ -155,32 +154,37 @@ def reshape(X, /, shape, *, order="C", copy=None):
"Reshaping the array requires a copy, but no copying was "
"requested by using copy=False"
)
copy_q = X.sycl_queue
if copy_required or (copy is True):
# must perform a copy
flat_res = dpt.usm_ndarray(
(X.size,),
dtype=X.dtype,
buffer=X.usm_type,
buffer_ctor_kwargs={"queue": X.sycl_queue},
buffer_ctor_kwargs={"queue": copy_q},
)
if order == "C":
hev, _ = _copy_usm_ndarray_for_reshape(
src=X, dst=flat_res, sycl_queue=X.sycl_queue
src=X, dst=flat_res, sycl_queue=copy_q
)
hev.wait()
else:
for i in range(X.size):
_copy_from_usm_ndarray_to_usm_ndarray(
flat_res[i], X[np.unravel_index(i, X.shape, order=order)]
)
X_t = dpt.permute_dims(X, range(X.ndim - 1, -1, -1))
hev, _ = _copy_usm_ndarray_for_reshape(
src=X_t, dst=flat_res, sycl_queue=copy_q
)
hev.wait()
return dpt.usm_ndarray(
tuple(shape), dtype=X.dtype, buffer=flat_res, order=order
)
# can form a view
if (len(shape) == X.ndim) and all(
s1 == s2 for s1, s2 in zip(shape, X.shape)
):
return X
return dpt.usm_ndarray(
shape,
dtype=X.dtype,
buffer=X,
strides=tuple(newsts),
offset=X.__sycl_usm_array_interface__.get("offset", 0),
offset=X._element_offset,
)
27 changes: 27 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,6 +1454,33 @@ def test_reshape():
assert A4.shape == requested_shape


def test_reshape_orderF():
try:
a = dpt.arange(6 * 3 * 4, dtype="i4")
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
b = dpt.reshape(a, (6, 2, 6))
c = dpt.reshape(b, (9, 8), order="F")
assert c.flags.f_contiguous
assert c._pointer != b._pointer
assert b._pointer == a._pointer

a_np = np.arange(6 * 3 * 4, dtype="i4")
b_np = np.reshape(a_np, (6, 2, 6))
c_np = np.reshape(b_np, (9, 8), order="F")
assert np.array_equal(c_np, dpt.asnumpy(c))


def test_reshape_noop():
"""Per gh-1664"""
try:
a = dpt.ones((2, 1))
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
b = dpt.reshape(a, (2, 1))
assert b is a


def test_reshape_zero_size():
try:
a = dpt.empty((0,))
Expand Down

0 comments on commit c994666

Please sign in to comment.