diff --git a/dpctl/tensor/_reshape.py b/dpctl/tensor/_reshape.py
index 99726af4d3..ac4a04cac4 100644
--- a/dpctl/tensor/_reshape.py
+++ b/dpctl/tensor/_reshape.py
@@ -19,7 +19,11 @@
 
 import dpctl.tensor as dpt
 from dpctl.tensor._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray
-from dpctl.tensor._tensor_impl import _copy_usm_ndarray_for_reshape
+from dpctl.tensor._tensor_impl import (
+    _copy_usm_ndarray_for_reshape,
+    _ravel_multi_index,
+    _unravel_index,
+)
 
 __doc__ = "Implementation module for :func:`dpctl.tensor.reshape`."
 
@@ -36,6 +40,14 @@ def _make_unit_indexes(shape):
     return mi
 
 
+def ti_unravel_index(flat_index, shape, order="C"):
+    return _unravel_index(flat_index, shape, order)
+
+
+def ti_ravel_multi_index(multi_index, shape, order="C"):
+    return _ravel_multi_index(multi_index, shape, order)
+
+
 def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
     """
     When reshaping array with `old_sh` shape and `old_sts` strides
@@ -47,11 +59,11 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
         sum(
             st_i * ind_i
             for st_i, ind_i in zip(
-                old_sts, np.unravel_index(flat_index, old_sh, order=order)
+                old_sts, ti_unravel_index(flat_index, old_sh, order=order)
             )
         )
         for flat_index in [
-            np.ravel_multi_index(unitvec, new_sh, order=order)
+            ti_ravel_multi_index(unitvec, new_sh, order=order)
             for unitvec in eye_new_mi
         ]
     ]
@@ -60,11 +72,11 @@
         sum(
             st_i * ind_i
             for st_i, ind_i in zip(
-                new_sts, np.unravel_index(flat_index, new_sh, order=order)
+                new_sts, ti_unravel_index(flat_index, new_sh, order=order)
            )
         )
         for flat_index in [
-            np.ravel_multi_index(unitvec, old_sh, order=order)
+            ti_ravel_multi_index(unitvec, old_sh, order=order)
             for unitvec in eye_old_mi
         ]
     ]
@@ -123,7 +135,13 @@ def reshape(X, shape, order="C", copy=None):
                 "value which can only be -1"
             )
     if negative_ones_count:
-        v = X.size // (-np.prod(shape))
+        sz = -np.prod(shape)
+        if sz == 0:
+            raise ValueError(
+                f"Can not reshape array of size {X.size} into "
+                f"shape {tuple(i for i in shape if i >= 0)}"
+            )
+        v = X.size // sz
         shape = [v if d == -1 else d for d in shape]
     if X.size != np.prod(shape):
         raise ValueError(f"Can not reshape into {shape}")
diff --git a/dpctl/tensor/libtensor/source/simplify_iteration_space.cpp b/dpctl/tensor/libtensor/source/simplify_iteration_space.cpp
index e11495204a..90b88bcd0e 100644
--- a/dpctl/tensor/libtensor/source/simplify_iteration_space.cpp
+++ b/dpctl/tensor/libtensor/source/simplify_iteration_space.cpp
@@ -71,12 +71,17 @@ void simplify_iteration_space_1(int &nd,
         nd = contracted_nd;
     }
     else if (nd == 1) {
+        offset = 0;
         // Populate vectors
         simplified_shape.reserve(nd);
         simplified_shape.push_back(shape[0]);
 
         simplified_strides.reserve(nd);
-        simplified_strides.push_back(strides[0]);
+        simplified_strides.push_back((strides[0] >= 0) ? strides[0]
+                                                       : -strides[0]);
+        if ((strides[0] < 0) && (shape[0] > 1)) {
+            offset += (shape[0] - 1) * strides[0];
+        }
 
         assert(simplified_shape.size() == static_cast<size_t>(nd));
         assert(simplified_strides.size() == static_cast<size_t>(nd));
@@ -128,17 +133,27 @@ void simplify_iteration_space(int &nd,
         nd = contracted_nd;
     }
     else if (nd == 1) {
+        src_offset = 0;
+        dst_offset = 0;
         // Populate vectors
         simplified_shape.reserve(nd);
         simplified_shape.push_back(shape[0]);
         assert(simplified_shape.size() == static_cast<size_t>(nd));
 
         simplified_src_strides.reserve(nd);
-        simplified_src_strides.push_back(src_strides[0]);
+        simplified_src_strides.push_back(
+            (src_strides[0] >= 0) ? src_strides[0] : -src_strides[0]);
+        if ((src_strides[0] < 0) && (shape[0] > 1)) {
+            src_offset += (shape[0] - 1) * src_strides[0];
+        }
         assert(simplified_src_strides.size() == static_cast<size_t>(nd));
 
         simplified_dst_strides.reserve(nd);
-        simplified_dst_strides.push_back(dst_strides[0]);
+        simplified_dst_strides.push_back(
+            (dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
+        if ((dst_strides[0] < 0) && (shape[0] > 1)) {
+            dst_offset += (shape[0] - 1) * dst_strides[0];
+        }
         assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
     }
 }
@@ -202,21 +217,36 @@ void simplify_iteration_space_3(
         nd = contracted_nd;
     }
     else if (nd == 1) {
+        src1_offset = 0;
+        src2_offset = 0;
+        dst_offset = 0;
         // Populate vectors
         simplified_shape.reserve(nd);
         simplified_shape.push_back(shape[0]);
         assert(simplified_shape.size() == static_cast<size_t>(nd));
 
         simplified_src1_strides.reserve(nd);
-        simplified_src1_strides.push_back(src1_strides[0]);
+        simplified_src1_strides.push_back(
+            (src1_strides[0] >= 0) ? src1_strides[0] : -src1_strides[0]);
+        if ((src1_strides[0] < 0) && (shape[0] > 1)) {
+            src1_offset += src1_strides[0] * (shape[0] - 1);
+        }
         assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
 
         simplified_src2_strides.reserve(nd);
-        simplified_src2_strides.push_back(src2_strides[0]);
+        simplified_src2_strides.push_back(
+            (src2_strides[0] >= 0) ? src2_strides[0] : -src2_strides[0]);
+        if ((src2_strides[0] < 0) && (shape[0] > 1)) {
+            src2_offset += src2_strides[0] * (shape[0] - 1);
+        }
         assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
 
         simplified_dst_strides.reserve(nd);
-        simplified_dst_strides.push_back(dst_strides[0]);
+        simplified_dst_strides.push_back(
+            (dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
+        if ((dst_strides[0] < 0) && (shape[0] > 1)) {
+            dst_offset += dst_strides[0] * (shape[0] - 1);
+        }
         assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
     }
 }
@@ -293,29 +323,129 @@ void simplify_iteration_space_4(
         nd = contracted_nd;
     }
     else if (nd == 1) {
+        src1_offset = 0;
+        src2_offset = 0;
+        src3_offset = 0;
+        dst_offset = 0;
         // Populate vectors
         simplified_shape.reserve(nd);
         simplified_shape.push_back(shape[0]);
         assert(simplified_shape.size() == static_cast<size_t>(nd));
 
         simplified_src1_strides.reserve(nd);
-        simplified_src1_strides.push_back(src1_strides[0]);
+        simplified_src1_strides.push_back(
+            (src1_strides[0] >= 0) ? src1_strides[0] : -src1_strides[0]);
+        if ((src1_strides[0] < 0) && (shape[0] > 1)) {
+            src1_offset += src1_strides[0] * (shape[0] - 1);
+        }
         assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
 
         simplified_src2_strides.reserve(nd);
-        simplified_src2_strides.push_back(src2_strides[0]);
+        simplified_src2_strides.push_back(
+            (src2_strides[0] >= 0) ? src2_strides[0] : -src2_strides[0]);
+        if ((src2_strides[0] < 0) && (shape[0] > 1)) {
+            src2_offset += src2_strides[0] * (shape[0] - 1);
+        }
         assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
 
         simplified_src3_strides.reserve(nd);
-        simplified_src3_strides.push_back(src3_strides[0]);
+        simplified_src3_strides.push_back(
+            (src3_strides[0] >= 0) ? src3_strides[0] : -src3_strides[0]);
+        if ((src3_strides[0] < 0) && (shape[0] > 1)) {
+            src3_offset += src3_strides[0] * (shape[0] - 1);
+        }
         assert(simplified_src3_strides.size() == static_cast<size_t>(nd));
 
         simplified_dst_strides.reserve(nd);
-        simplified_dst_strides.push_back(dst_strides[0]);
+        simplified_dst_strides.push_back(
+            (dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
+        if ((dst_strides[0] < 0) && (shape[0] > 1)) {
+            dst_offset += dst_strides[0] * (shape[0] - 1);
+        }
         assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
     }
 }
 
+py::ssize_t _ravel_multi_index_c(std::vector<py::ssize_t> const &mi,
+                                 std::vector<py::ssize_t> const &shape)
+{
+    size_t nd = shape.size();
+    if (nd != mi.size()) {
+        throw py::value_error(
+            "Multi-index and shape vectors must have the same length.");
+    }
+
+    py::ssize_t flat_index = 0;
+    py::ssize_t s = 1;
+    for (size_t i = 0; i < nd; ++i) {
+        flat_index += mi.at(nd - 1 - i) * s;
+        s *= shape.at(nd - 1 - i);
+    }
+
+    return flat_index;
+}
+
+py::ssize_t _ravel_multi_index_f(std::vector<py::ssize_t> const &mi,
+                                 std::vector<py::ssize_t> const &shape)
+{
+    size_t nd = shape.size();
+    if (nd != mi.size()) {
+        throw py::value_error(
+            "Multi-index and shape vectors must have the same length.");
+    }
+
+    py::ssize_t flat_index = 0;
+    py::ssize_t s = 1;
+    for (size_t i = 0; i < nd; ++i) {
+        flat_index += mi.at(i) * s;
+        s *= shape.at(i);
+    }
+
+    return flat_index;
+}
+
+std::vector<py::ssize_t> _unravel_index_c(py::ssize_t flat_index,
+                                          std::vector<py::ssize_t> const &shape)
+{
+    size_t nd = shape.size();
+    std::vector<py::ssize_t> mi;
+    mi.resize(nd);
+
+    py::ssize_t i_ = flat_index;
+    for (size_t dim = 0; dim + 1 < nd; ++dim) {
+        const py::ssize_t si = shape[nd - 1 - dim];
+        const py::ssize_t q = i_ / si;
+        const py::ssize_t r = (i_ - q * si);
+        mi[nd - 1 - dim] = r;
+        i_ = q;
+    }
+    if (nd) {
+        mi[0] = i_;
+    }
+    return mi;
+}
+
+std::vector<py::ssize_t> _unravel_index_f(py::ssize_t flat_index,
+                                          std::vector<py::ssize_t> const &shape)
+{
+    size_t nd = shape.size();
+    std::vector<py::ssize_t> mi;
+    mi.resize(nd);
+
+    py::ssize_t i_ = flat_index;
+    for (size_t dim = 0; dim + 1 < nd; ++dim) {
+        const py::ssize_t si = shape[dim];
+        const py::ssize_t q = i_ / si;
+        const py::ssize_t r = (i_ - q * si);
+        mi[dim] = r;
+        i_ = q;
+    }
+    if (nd) {
+        mi[nd - 1] = i_;
+    }
+    return mi;
+}
+
 } // namespace py_internal
 } // namespace tensor
 } // namespace dpctl
diff --git a/dpctl/tensor/libtensor/source/simplify_iteration_space.hpp b/dpctl/tensor/libtensor/source/simplify_iteration_space.hpp
index 9a60830d1d..356afca08d 100644
--- a/dpctl/tensor/libtensor/source/simplify_iteration_space.hpp
+++ b/dpctl/tensor/libtensor/source/simplify_iteration_space.hpp
@@ -90,6 +90,14 @@ void simplify_iteration_space_4(int &,
                                 py::ssize_t &,
                                 py::ssize_t &);
 
+py::ssize_t _ravel_multi_index_c(std::vector<py::ssize_t> const &,
+                                 std::vector<py::ssize_t> const &);
+py::ssize_t _ravel_multi_index_f(std::vector<py::ssize_t> const &,
+                                 std::vector<py::ssize_t> const &);
+std::vector<py::ssize_t> _unravel_index_c(py::ssize_t,
+                                          std::vector<py::ssize_t> const &);
+std::vector<py::ssize_t> _unravel_index_f(py::ssize_t,
+                                          std::vector<py::ssize_t> const &);
 } // namespace py_internal
 } // namespace tensor
 } // namespace dpctl
diff --git a/dpctl/tensor/libtensor/source/tensor_py.cpp b/dpctl/tensor/libtensor/source/tensor_py.cpp
index 00f2b40869..2cf627be18 100644
--- a/dpctl/tensor/libtensor/source/tensor_py.cpp
+++ b/dpctl/tensor/libtensor/source/tensor_py.cpp
@@ -42,6 +42,7 @@
 #include "full_ctor.hpp"
 #include "integer_advanced_indexing.hpp"
 #include "linear_sequences.hpp"
+#include "simplify_iteration_space.hpp"
 #include "triul_ctor.hpp"
 #include "utils/memory_overlap.hpp"
 #include "utils/strided_iters.hpp"
@@ -182,6 +183,37 @@ PYBIND11_MODULE(_tensor_impl, m)
           "as the original "
           "iterator, possibly in a different order.");
 
+    static constexpr char orderC = 'C';
+    m.def(
+        "_ravel_multi_index",
+        [](const std::vector<py::ssize_t> &mi,
+           const std::vector<py::ssize_t> &shape, char order = 'C') {
+            if (order == orderC) {
+                return dpctl::tensor::py_internal::_ravel_multi_index_c(mi,
+                                                                        shape);
+            }
+            else {
+                return dpctl::tensor::py_internal::_ravel_multi_index_f(mi,
+                                                                        shape);
+            }
+        },
+        "");
+
+    m.def(
+        "_unravel_index",
+        [](py::ssize_t flat_index, const std::vector<py::ssize_t> &shape,
+           char order = 'C') {
+            if (order == orderC) {
+                return dpctl::tensor::py_internal::_unravel_index_c(flat_index,
+                                                                    shape);
+            }
+            else {
+                return dpctl::tensor::py_internal::_unravel_index_f(flat_index,
+                                                                    shape);
+            }
+        },
+        "");
+
     m.def("_copy_usm_ndarray_for_reshape", &copy_usm_ndarray_for_reshape,
           "Copies from usm_ndarray `src` into usm_ndarray `dst` with the same "
           "number of elements using underlying 'C'-contiguous order for flat "
diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py
index affdbcc4da..cd4782d2ff 100644
--- a/dpctl/tests/test_usm_ndarray_ctor.py
+++ b/dpctl/tests/test_usm_ndarray_ctor.py
@@ -1293,6 +1293,26 @@ def test_reshape():
     assert A4.shape == requested_shape
 
 
+def test_reshape_zero_size():
+    try:
+        a = dpt.empty((0,))
+    except dpctl.SyclDeviceCreationError:
+        pytest.skip("No SYCL devices available")
+    with pytest.raises(ValueError):
+        dpt.reshape(a, (-1, 0))
+
+
+def test_reshape_large_ndim():
+    ndim = 32
+    idx = tuple(1 if i + 1 < ndim else ndim for i in range(ndim))
+    try:
+        d = dpt.ones(ndim, dtype="i4")
+    except dpctl.SyclDeviceCreationError:
+        pytest.skip("No SYCL devices available")
+    d = dpt.reshape(d, idx)
+    assert d.shape == idx
+
+
 def test_reshape_copy_kwrd():
     try:
         X = dpt.usm_ndarray((2, 3), "i4")
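
For reference, the new `_ravel_multi_index`/`_unravel_index` bindings are intended
as drop-in replacements for the `np.ravel_multi_index`/`np.unravel_index` calls
they displace in `reshaped_strides`. A quick sanity check of the intended
semantics, assuming pybind11's STL casters convert the returned
std::vector<py::ssize_t> to a Python list:

    import numpy as np

    from dpctl.tensor._tensor_impl import _ravel_multi_index, _unravel_index

    # 'C' order: last index varies fastest, matching np.ravel_multi_index
    assert _ravel_multi_index([1, 2], [3, 4], "C") == np.ravel_multi_index(
        (1, 2), (3, 4), order="C"
    )  # flat index 1 * 4 + 2 == 6
    # 'F' order: first index varies fastest
    assert _ravel_multi_index([1, 2], [3, 4], "F") == 1 + 2 * 3  # == 7
    # unraveling inverts the mapping for the matching order
    assert _unravel_index(6, [3, 4], "C") == [1, 2]
    assert _unravel_index(7, [3, 4], "F") == [1, 2]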