From 36d21df575be48d78acbecc7ae5c4307aab49c6b Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 7 Oct 2025 03:28:05 -0700 Subject: [PATCH 01/18] Implement getrs_batch lapack extension --- dpnp/backend/extensions/lapack/CMakeLists.txt | 1 + dpnp/backend/extensions/lapack/getrs.hpp | 15 + .../backend/extensions/lapack/getrs_batch.cpp | 336 ++++++++++++++++++ dpnp/backend/extensions/lapack/lapack_py.cpp | 13 +- .../extensions/lapack/types_matrix.hpp | 28 ++ 5 files changed, 392 insertions(+), 1 deletion(-) create mode 100644 dpnp/backend/extensions/lapack/getrs_batch.cpp diff --git a/dpnp/backend/extensions/lapack/CMakeLists.txt b/dpnp/backend/extensions/lapack/CMakeLists.txt index 1f35ee4d47c1..8e75d42882a2 100644 --- a/dpnp/backend/extensions/lapack/CMakeLists.txt +++ b/dpnp/backend/extensions/lapack/CMakeLists.txt @@ -37,6 +37,7 @@ set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getrs.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/getrs_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/heevd_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp diff --git a/dpnp/backend/extensions/lapack/getrs.hpp b/dpnp/backend/extensions/lapack/getrs.hpp index d8952f3f0b3f..9babe85ae90d 100644 --- a/dpnp/backend/extensions/lapack/getrs.hpp +++ b/dpnp/backend/extensions/lapack/getrs.hpp @@ -40,5 +40,20 @@ extern std::pair oneapi::mkl::transpose trans, const std::vector &depends = {}); +extern std::pair + getrs_batch(sycl::queue &exec_q, + const dpctl::tensor::usm_ndarray &a_array, + const dpctl::tensor::usm_ndarray &ipiv_array, + const dpctl::tensor::usm_ndarray &b_array, + oneapi::mkl::transpose trans, + std::int64_t n, + std::int64_t nrhs, + std::int64_t stride_a, + std::int64_t stride_ipiv, + std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &depends = {}); + extern void init_getrs_dispatch_vector(void); +extern void init_getrs_batch_dispatch_vector(void); } // namespace dpnp::extensions::lapack diff --git a/dpnp/backend/extensions/lapack/getrs_batch.cpp b/dpnp/backend/extensions/lapack/getrs_batch.cpp new file mode 100644 index 000000000000..f95a85d2b25e --- /dev/null +++ b/dpnp/backend/extensions/lapack/getrs_batch.cpp @@ -0,0 +1,336 @@ +//***************************************************************************** +// Copyright (c) 2025, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include +#include +#include + +#include +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +#include "getrs.hpp" +#include "linalg_exceptions.hpp" +#include "types_matrix.hpp" + +namespace dpnp::extensions::lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; +namespace td_ns = dpctl::tensor::type_dispatch; + +typedef sycl::event (*getrs_batch_impl_fn_ptr_t)( + sycl::queue &, + oneapi::mkl::transpose, // trans + const std::int64_t, // n + const std::int64_t, // nrhs + char *, // a + std::int64_t, // lda + std::int64_t, // stride_a + std::int64_t *, // ipiv + std::int64_t, // stride_ipiv + char *, // b + std::int64_t, // ldb + std::int64_t, // stride_b + std::int64_t, // batch_size + std::vector &, + const std::vector &); + +static getrs_batch_impl_fn_ptr_t getrs_batch_dispatch_vector[td_ns::num_types]; + +template +static sycl::event getrs_batch_impl(sycl::queue &exec_q, + oneapi::mkl::transpose trans, + const std::int64_t n, + const std::int64_t nrhs, + char *in_a, + std::int64_t lda, + std::int64_t stride_a, + std::int64_t *ipiv, + std::int64_t stride_ipiv, + char *in_b, + std::int64_t ldb, + std::int64_t stride_b, + std::int64_t batch_size, + std::vector &host_task_events, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + T *b = reinterpret_cast(in_b); + + const std::int64_t scratchpad_size = + mkl_lapack::getrs_batch_scratchpad_size(exec_q, trans, n, nrhs, lda, + stride_a, stride_ipiv, ldb, + stride_b, batch_size); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + bool is_exception_caught = false; + + sycl::event getrs_batch_event; + try { + scratchpad = sycl::malloc_device(scratchpad_size, exec_q); + + getrs_batch_event = mkl_lapack::getrs_batch( + exec_q, + trans, // Specifies the operation: whether or not to transpose + // matrix A. Can be 'N' for no transpose, 'T' for transpose, + // and 'C' for conjugate transpose. + n, // The order of the square matrix A + // and the number of rows in matrix B (0 ≤ n). + // It must be a non-negative integer. + nrhs, // The number of right-hand sides, + // i.e., the number of columns in matrix B (0 ≤ nrhs). + a, // Pointer to the square matrix A (n x n). + lda, // The leading dimension of matrix A, must be at least max(1, + // n). It must be at least max(1, n). + stride_a, // Stride between consecutive A matrices in the batch. + ipiv, // Pointer to the output array of pivot indices that were used + // during factorization (n, ). + stride_ipiv, // Stride between consecutive pivot arrays in the + // batch. + b, // Pointer to the matrix B of right-hand sides (ldb, nrhs). + ldb, // The leading dimension of matrix B, must be at least max(1, + // n). + stride_b, // Stride between consecutive B matrices in the batch. + batch_size, // Total number of matrices in the batch. + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + info = e.info(); + + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail(); + } + else if (info > 0) { + is_exception_caught = false; + if (scratchpad != nullptr) { + dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, + exec_q); + } + throw LinAlgError("The solve could not be completed."); + } + else { + error_msg << "Unexpected MKL exception caught during getrs() " + "call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg << "Unexpected SYCL exception caught during getrs() call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) { + dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, exec_q); + } + + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(getrs_batch_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { + dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, ctx); + }); + }); + host_task_events.push_back(clean_up_event); + return getrs_batch_event; +} + +std::pair + getrs_batch(sycl::queue &exec_q, + const dpctl::tensor::usm_ndarray &a_array, + const dpctl::tensor::usm_ndarray &ipiv_array, + const dpctl::tensor::usm_ndarray &b_array, + oneapi::mkl::transpose trans, + std::int64_t n, + std::int64_t nrhs, + std::int64_t stride_a, + std::int64_t stride_ipiv, + std::int64_t stride_b, + std::int64_t batch_size, + const std::vector &depends) +{ + const int a_array_nd = a_array.get_ndim(); + const int b_array_nd = b_array.get_ndim(); + const int ipiv_array_nd = ipiv_array.get_ndim(); + + if (a_array_nd < 3) { + throw py::value_error( + "The LU-factorized array has ndim=" + std::to_string(a_array_nd) + + ", but an array with ndim >= 3 is expected"); + } + if (b_array_nd < 3) { + throw py::value_error("The right-hand sides array has ndim=" + + std::to_string(b_array_nd) + + ", but an array with ndim >= 3 is expected"); + } + if (ipiv_array_nd < 1) { + throw py::value_error("The array of pivot indices has ndim=" + + std::to_string(ipiv_array_nd) + + ", but an array with ndim >= 2 is expected"); + } + + if (ipiv_array_nd != a_array_nd - 1) { + throw py::value_error( + "The array of pivot indices has ndim=" + + std::to_string(ipiv_array_nd) + + ", but an array with ndim=" + std::to_string(a_array_nd - 1) + + " is expected to match LU batch dimensions"); + } + + const py::ssize_t *a_array_shape = a_array.get_shape_raw(); + + if (a_array_shape[a_array_nd - 1] != a_array_shape[a_array_nd - 2]) { + throw py::value_error( + "The last two dimensions of the LU array must be square," + " but got a shape of (" + + std::to_string(a_array_shape[a_array_nd - 1]) + ", " + + std::to_string(a_array_shape[a_array_nd - 2]) + ")."); + } + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible(exec_q, + {a_array, b_array, ipiv_array})) + { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(a_array, b_array)) { + throw py::value_error("The LU-factorized and right-hand sides arrays " + "are overlapping segments of memory"); + } + + bool is_a_array_c_contig = a_array.is_c_contiguous(); + bool is_a_array_f_contig = a_array.is_f_contiguous(); + bool is_b_array_f_contig = b_array.is_f_contiguous(); + bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous(); + bool is_ipiv_array_f_contig = ipiv_array.is_f_contiguous(); + if (!is_a_array_c_contig && !is_a_array_f_contig) { + throw py::value_error("The LU-factorized array " + "must be either C-contiguous " + "or F-contiguous"); + } + if (!is_b_array_f_contig) { + throw py::value_error("The right-hand sides array " + "must be F-contiguous"); + } + if (!is_ipiv_array_c_contig && !is_ipiv_array_f_contig) { + throw py::value_error("The array of pivot indices " + "must be contiguous"); + } + + auto array_types = td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + int b_array_type_id = + array_types.typenum_to_lookup_id(b_array.get_typenum()); + + if (a_array_type_id != b_array_type_id) { + throw py::value_error("The types of the LU-factorized and " + "right-hand sides arrays are mismatched"); + } + + getrs_batch_impl_fn_ptr_t getrs_batch_fn = + getrs_batch_dispatch_vector[a_array_type_id]; + if (getrs_batch_fn == nullptr) { + throw py::value_error( + "No getrs_batch implementation defined for the provided type " + "of the input matrix."); + } + + auto ipiv_types = td_ns::usm_ndarray_types(); + int ipiv_array_type_id = + ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum()); + + if (ipiv_array_type_id != static_cast(td_ns::typenum_t::INT64)) { + throw py::value_error("The type of 'ipiv_array' must be int64."); + } + + const std::int64_t lda = std::max(1UL, n); + const std::int64_t ldb = std::max(1UL, n); + + char *a_array_data = a_array.get_data(); + char *b_array_data = b_array.get_data(); + char *ipiv_array_data = ipiv_array.get_data(); + + std::int64_t *ipiv = reinterpret_cast(ipiv_array_data); + + std::vector host_task_events; + sycl::event getrs_batch_ev = getrs_batch_fn( + exec_q, trans, n, nrhs, a_array_data, lda, stride_a, ipiv, stride_ipiv, + b_array_data, ldb, stride_b, batch_size, host_task_events, depends); + + sycl::event args_ev = dpctl::utils::keep_args_alive( + exec_q, {a_array, b_array, ipiv_array}, host_task_events); + + return std::make_pair(args_ev, getrs_batch_ev); +} + +template +struct GetrsBatchContigFactory +{ + fnT get() + { + if constexpr (types::GetrsBatchTypePairSupportFactory::is_defined) { + return getrs_batch_impl; + } + else { + return nullptr; + } + } +}; + +void init_getrs_batch_dispatch_vector(void) +{ + td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(getrs_batch_dispatch_vector); +} +} // namespace dpnp::extensions::lapack diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 46471cc2f366..b40c8e38876f 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -58,6 +58,7 @@ void init_dispatch_vectors(void) lapack_ext::init_getrf_batch_dispatch_vector(); lapack_ext::init_getrf_dispatch_vector(); lapack_ext::init_getri_batch_dispatch_vector(); + lapack_ext::init_getrs_batch_dispatch_vector(); lapack_ext::init_getrs_dispatch_vector(); lapack_ext::init_orgqr_batch_dispatch_vector(); lapack_ext::init_orgqr_dispatch_vector(); @@ -164,12 +165,22 @@ PYBIND11_MODULE(_lapack_impl, m) m.def("_getrs", &lapack_ext::getrs, "Call `getrs` from OneMKL LAPACK library to return " - "the solves of linear equations with an LU-factored " + "the solutions of linear equations with an LU-factored " "square coefficient matrix, with multiple right-hand sides", py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), py::arg("b_array"), py::arg("trans") = oneapi::mkl::transpose::N, py::arg("depends") = py::list()); + m.def("_getrs_batch", &lapack_ext::getrs_batch, + "Call `getrs_batch` from OneMKL LAPACK library to return " + "the solutions of batch linear equations with an LU-factored " + "square coefficient matrix, with multiple right-hand sides", + py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), + py::arg("b_array"), py::arg("trans") = oneapi::mkl::transpose::N, + py::arg("n"), py::arg("nrhs"), py::arg("stride_a"), + py::arg("stride_ipiv"), py::arg("stride_b"), py::arg("batch_size"), + py::arg("depends") = py::list()); + m.def("_orgqr_batch", &lapack_ext::orgqr_batch, "Call `_orgqr_batch` from OneMKL LAPACK library to return " "the real orthogonal matrix Qi of the QR factorization " diff --git a/dpnp/backend/extensions/lapack/types_matrix.hpp b/dpnp/backend/extensions/lapack/types_matrix.hpp index 81e287d5c163..7fc4fc725550 100644 --- a/dpnp/backend/extensions/lapack/types_matrix.hpp +++ b/dpnp/backend/extensions/lapack/types_matrix.hpp @@ -245,6 +245,34 @@ struct GetrsTypePairSupportFactory dpctl_td_ns::NotDefinedEntry>::is_defined; }; +/** + * @brief A factory to define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::getrs_batch + * function. + * + * @tparam T Type of array containing batched input matrix (LU-factored form) + * and the array of multiple dependent variables, + * as well as the output array for storing the solutions to a system of linear + * equations. + */ +template +struct GetrsBatchTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; + /** * @brief A factory to define pairs of supported types for which * MKL LAPACK library provides support in oneapi::mkl::lapack::heevd From 4270823ed17f7eb19b6e1563137b776e1a7c1e81 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 9 Oct 2025 05:51:00 -0700 Subject: [PATCH 02/18] Add more validation checks in getrs_batch --- .../backend/extensions/lapack/getrs_batch.cpp | 40 ++++++++++++------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/dpnp/backend/extensions/lapack/getrs_batch.cpp b/dpnp/backend/extensions/lapack/getrs_batch.cpp index f95a85d2b25e..b2905e82690c 100644 --- a/dpnp/backend/extensions/lapack/getrs_batch.cpp +++ b/dpnp/backend/extensions/lapack/getrs_batch.cpp @@ -128,6 +128,20 @@ static sycl::event getrs_batch_impl(sycl::queue &exec_q, scratchpad, // Pointer to scratchpad memory to be used by MKL // routine for storing intermediate results. scratchpad_size, depends); + } catch (mkl_lapack::batch_error const &be) { + // Get the indices of matrices within the batch that encountered an + // error + auto error_matrices_ids = be.ids(); + + // OneMKL batched functions throw a single `batch_error` + // instead of per-matrix exceptions or an info array. + // This is interpreted as a computation_error (singular matrix), + // consistent with non-batched LAPACK behavior. + is_exception_caught = false; + if (scratchpad != nullptr) { + dpctl::tensor::alloc_utils::sycl_free_noexcept(scratchpad, exec_q); + } + throw LinAlgError("The solve could not be completed."); } catch (mkl_lapack::exception const &e) { is_exception_caught = true; info = e.info(); @@ -203,10 +217,10 @@ std::pair "The LU-factorized array has ndim=" + std::to_string(a_array_nd) + ", but an array with ndim >= 3 is expected"); } - if (b_array_nd < 3) { + if (b_array_nd < 2) { throw py::value_error("The right-hand sides array has ndim=" + std::to_string(b_array_nd) + - ", but an array with ndim >= 3 is expected"); + ", but an array with ndim >= 2 is expected"); } if (ipiv_array_nd < 1) { throw py::value_error("The array of pivot indices has ndim=" + @@ -214,6 +228,14 @@ std::pair ", but an array with ndim >= 2 is expected"); } + const py::ssize_t *a_array_shape = a_array.get_shape_raw(); + if (a_array_shape[0] != a_array_shape[1]) { + throw py::value_error("Expected batch of square matrices , but got " + "matrix shape (" + + std::to_string(a_array_shape[0]) + ", " + + std::to_string(a_array_shape[1]) + ") in batch"); + } + if (ipiv_array_nd != a_array_nd - 1) { throw py::value_error( "The array of pivot indices has ndim=" + @@ -222,16 +244,6 @@ std::pair " is expected to match LU batch dimensions"); } - const py::ssize_t *a_array_shape = a_array.get_shape_raw(); - - if (a_array_shape[a_array_nd - 1] != a_array_shape[a_array_nd - 2]) { - throw py::value_error( - "The last two dimensions of the LU array must be square," - " but got a shape of (" + - std::to_string(a_array_shape[a_array_nd - 1]) + ", " + - std::to_string(a_array_shape[a_array_nd - 2]) + ")."); - } - // check compatibility of execution queue and allocation queue if (!dpctl::utils::queues_are_compatible(exec_q, {a_array, b_array, ipiv_array})) @@ -281,7 +293,7 @@ std::pair if (getrs_batch_fn == nullptr) { throw py::value_error( "No getrs_batch implementation defined for the provided type " - "of the input matrix."); + "of the input matrix"); } auto ipiv_types = td_ns::usm_ndarray_types(); @@ -289,7 +301,7 @@ std::pair ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum()); if (ipiv_array_type_id != static_cast(td_ns::typenum_t::INT64)) { - throw py::value_error("The type of 'ipiv_array' must be int64."); + throw py::value_error("The type of 'ipiv_array' must be int64"); } const std::int64_t lda = std::max(1UL, n); From 14feed3d2deed28f45b16aed97a7961d4122c10a Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 9 Oct 2025 06:43:27 -0700 Subject: [PATCH 03/18] Implement _batched_lu_solve --- dpnp/linalg/dpnp_utils_linalg.py | 180 +++++++++++++++++++++++++++---- 1 file changed, 159 insertions(+), 21 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 44a3816cc165..16dc78e7f68a 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -107,6 +107,37 @@ class SVDResult(NamedTuple): } +def _align_lu_solve_broadcast(lu, b): + """Align LU and RHS batch dimensions with SciPy-like rules.""" + lu_shape = lu.shape + b_shape = b.shape + + if b.ndim < 2: + if lu_shape[-2] != b_shape[0]: + raise ValueError( + f"Shapes of lu {lu_shape} and b {b_shape} are incompatible" + ) + b = dpnp.broadcast_to(b, lu_shape[:-1]) + return lu, b + + if lu_shape[-2] != b_shape[-2]: + raise ValueError( + f"Shapes of lu {lu_shape} and b {b_shape} are incompatible" + ) + + # Use dpnp.broadcast_shapes() to align the resulting batch shapes + batch = dpnp.broadcast_shapes(lu_shape[:-2], b_shape[:-2]) + lu_bshape = batch + lu_shape[-2:] + b_bshape = batch + b_shape[-2:] + + if lu_shape != lu_bshape: + lu = dpnp.broadcast_to(lu, lu_bshape) + if b_shape != b_bshape: + b = dpnp.broadcast_to(b, b_bshape) + + return lu, b + + def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type): """ _batched_eigh(a, UPLO, eigen_mode, w_type, v_type) @@ -486,6 +517,109 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals return (a_h, ipiv_h) +def _batched_lu_solve(lu, piv, b, res_type, trans=0): + """Solve a batched equation system (SciPy-compatible behavior).""" + res_usm_type, exec_q = get_usm_allocations([lu, piv, b]) + + if b.size == 0: + return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type) + + b_ndim = b.ndim + + lu, b = _align_lu_solve_broadcast(lu, b) + + n = lu.shape[-1] + nrhs = b.shape[-1] if b_ndim > 1 else 1 + + # get 3d input arrays by reshape + if lu.ndim > 3: + lu = dpnp.reshape(lu, (-1, n, n)) + # get 2d pivot arrays by reshape + if piv.ndim > 2: + piv = dpnp.reshape(piv, (-1, n)) + batch_size = lu.shape[0] + + # Move batch axis to the end (n, n, batch) in Fortran order: + # required by getrs_batch + # and ensures each a[..., i] is F-contiguous for getrs_batch + lu = dpnp.moveaxis(lu, 0, -1) + + b_orig_shape = b.shape + if b.ndim > 2: + b = dpnp.reshape(b, (-1, n, nrhs)) + + # Move batch axis to the end (n, nrhs, batch) in Fortran order: + # required by getrs_batch + # and ensures each b[..., i] is F-contiguous for getrs_batch + b = dpnp.moveaxis(b, 0, -1) + + lu_usm_arr = dpnp.get_usm_ndarray(lu) + b_usm_arr = dpnp.get_usm_ndarray(b) + + # dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy, + # convert to 1-based for oneMKL getrs_batch + piv_h = piv + 1 + + _manager = dpu.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + + # oneMKL LAPACK getrs overwrites `lu`. + lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type) + + # use DPCTL tensor function to fill the сopy of the input array + # from the input array + ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=lu_usm_arr, + dst=lu_h.get_array(), + sycl_queue=lu.sycl_queue, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, lu_copy_ev) + + b_h = dpnp.empty_like(b, order="F", dtype=res_type, usm_type=res_usm_type) + ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=b_usm_arr, + dst=b_h.get_array(), + sycl_queue=b.sycl_queue, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, b_copy_ev) + dep_evs = [lu_copy_ev, b_copy_ev] + + lu_stride = lu_h.strides[-1] + piv_stride = piv.strides[0] + b_stride = b_h.strides[-1] + + if not isinstance(trans, int): + raise TypeError("`trans` must be an integer") + + trans_mkl = _map_trans_to_mkl(trans) + + # Call the LAPACK extension function _getrs_batch + # to solve the system of linear equations with an LU-factored + # coefficient square matrix, with multiple right-hand sides. + ht_ev, getrs_batch_ev = li._getrs_batch( + exec_q, + lu_h.get_array(), + piv_h.get_array(), + b_h.get_array(), + trans_mkl, + n, + nrhs, + lu_stride, + piv_stride, + b_stride, + batch_size, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, getrs_batch_ev) + + # Restore original shape: move batch axis back and reshape + b_h = dpnp.moveaxis(b_h, -1, 0).reshape(b_orig_shape) + + return b_h + + def _batched_solve(a, b, exec_q, res_usm_type, res_type): """ _batched_solve(a, b, exec_q, res_usm_type, res_type) @@ -1099,6 +1233,20 @@ def _is_empty_2d(arr): return arr.size == 0 and numpy.prod(arr.shape[-2:]) == 0 +def _map_trans_to_mkl(trans): + """Map SciPy-style trans code (0,1,2) to oneMKL transpose enum.""" + if not isinstance(trans, int): + raise TypeError("`trans` must be an integer") + + if trans == 0: + return li.Transpose.N + if trans == 1: + return li.Transpose.T + if trans == 2: + return li.Transpose.C + raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)") + + def _lu_factor(a, res_type): """ Compute pivoted LU decomposition. @@ -2493,18 +2641,9 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): res_type = _common_type(lu, b) - # TODO: add broadcasting - if lu.shape[0] != b.shape[0]: - raise ValueError( - f"Shapes of lu {lu.shape} and b {b.shape} are incompatible" - ) - if b.size == 0: return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type) - if lu.ndim > 2: - raise NotImplementedError("Batched matrices are not supported") - if check_finite: if not dpnp.isfinite(lu).all(): raise ValueError( @@ -2517,6 +2656,16 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): "Right-hand side array must not contain infs or NaNs" ) + if lu.ndim > 2: + # SciPy always copies each 2D slice, + # so `overwrite_b` is ignored here + return _batched_lu_solve(lu, piv, b, trans=trans, res_type=res_type) + + if lu.shape[0] != b.shape[0]: + raise ValueError( + f"Shapes of lu {lu.shape} and b {b.shape} are incompatible" + ) + lu_usm_arr = dpnp.get_usm_ndarray(lu) b_usm_arr = dpnp.get_usm_ndarray(b) @@ -2563,18 +2712,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): b_h = b dep_evs = [lu_copy_ev] - if not isinstance(trans, int): - raise TypeError("`trans` must be an integer") - - # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums - if trans == 0: - trans_mkl = li.Transpose.N - elif trans == 1: - trans_mkl = li.Transpose.T - elif trans == 2: - trans_mkl = li.Transpose.C - else: - raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)") + trans_mkl = _map_trans_to_mkl(trans) # Call the LAPACK extension function _getrs # to solve the system of linear equations with an LU-factored From 5e21a027df25b4e99997b98795ed93eb6e0de05c Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 9 Oct 2025 06:48:42 -0700 Subject: [PATCH 04/18] Add TestLuSolveBatched --- dpnp/tests/test_linalg.py | 238 +++++++++++++++++++++++++++++++++++++- 1 file changed, 235 insertions(+), 3 deletions(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index a25c237f846e..440509ddb6f3 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2300,9 +2300,6 @@ def test_strided_rhs(self): (4,), (4, 1), (4, 3), - # (1, 4, 3), - # (2, 4, 3), - # (1, 1, 4, 3) ], ) def test_broadcast_rhs(self, b_shape): @@ -2358,6 +2355,241 @@ def test_check_finite_raises(self, bad): ) +class TestLuSolveBatched: + @staticmethod + def _make_nonsingular_nd_np(shape, dtype, order): + A = generate_random_numpy_array(shape, dtype, order) + n = shape[-1] + A3 = A.reshape((-1, n, n)) + for B in A3: + off = numpy.sum(numpy.abs(B), axis=1) - numpy.abs(numpy.diag(B)) + B[numpy.arange(n), numpy.arange(n)] = A.dtype.type(off + 1.0) + A = A3.reshape(shape) + # Ensure reshapes did not break memory order + A = numpy.array(A, order=order) + return A + + @staticmethod + def _expected_x_shape(a_shape, b_shape): + n = a_shape[-1] + assert a_shape[-2] == n + + a_batch = a_shape[:-2] + if len(b_shape) >= 2 and b_shape[-2] == n: + # b : (..., n, nrhs) + k = b_shape[-1] + b_batch = b_shape[:-2] + exp_batch = numpy.broadcast_shapes(a_batch, b_batch) + return exp_batch + (n, k) + else: + # b : (..., n) + assert b_shape[-1] == n, "b's last dim must equal n" + b_batch = b_shape[:-1] + exp_batch = numpy.broadcast_shapes(a_batch, b_batch) + return exp_batch + (n,) + + @pytest.mark.parametrize( + "a_shape, b_shape", + [ + ((1, 2, 2), (2,)), + ((2, 4, 4), (4,)), + ((2, 4, 4), (4, 3)), + ((2, 4, 4), (2, 4, 4)), + ((2, 4, 4), (1, 4, 3)), + ((2, 4, 4), (2, 4, 2)), + ((2, 3, 4, 4), (1, 3, 4, 2)), + ((2, 3, 4, 4), (2, 1, 4, 2)), + ((3, 4, 4), (1, 1, 4, 2)), + ], + ) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_bool=True, no_none=True) + ) + def test_lu_solve_batched(self, a_shape, b_shape, dtype, order): + a_np = self._make_nonsingular_nd_np(a_shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + + b_np = generate_random_numpy_array(b_shape, dtype, order) + b_dp = dpnp.array(b_np, order=order) + + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.linalg.lu_solve( + (lu, piv), b_dp, overwrite_b=True, check_finite=False + ) + + exp_shape = self._expected_x_shape(a_shape, b_shape) + assert x.shape == exp_shape + + if b_dp.ndim > 1: + Ax = a_dp @ x + else: + Ax = (a_dp @ x[..., None])[..., 0] + b_exp = dpnp.broadcast_to(b_dp, exp_shape) + assert dpnp.allclose(Ax, b_exp, rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize("trans", [0, 1, 2]) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + def test_trans(self, trans, order, dtype): + a_shape = (3, 4, 4) + b_shape = (3, 4, 2) + + a_np = self._make_nonsingular_nd_np(a_shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + b_dp = dpnp.array( + generate_random_numpy_array(b_shape, dtype, order), order=order + ) + + lu, piv = dpnp.linalg.lu_factor( + a_dp, overwrite_a=False, check_finite=False + ) + x = dpnp.linalg.lu_solve( + (lu, piv), b_dp, trans=trans, overwrite_b=False, check_finite=False + ) + + if trans == 0: + lhs = a_dp @ x + elif trans == 1: + lhs = dpnp.swapaxes(a_dp, -1, -2) @ x + else: # trans == 2 + lhs = dpnp.conj(dpnp.swapaxes(a_dp, -1, -2)) @ x + + assert dpnp.allclose(lhs, b_dp, rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + @pytest.mark.parametrize("order", ["C", "F"]) + def test_overwrite(self, dtype, order): + a_np = self._make_nonsingular_nd_np((2, 4, 4), dtype, order) + a_dp = dpnp.array(a_np, order=order) + + lu, piv = dpnp.linalg.lu_factor( + a_dp, overwrite_a=False, check_finite=False + ) + + b_dp = dpnp.array( + generate_random_numpy_array((2, 4, 2), dtype, "F"), order="F" + ) + b_dp_orig = b_dp.copy() + x = dpnp.linalg.lu_solve( + (lu, piv), b_dp, overwrite_b=True, check_finite=False + ) + + assert x is not b_dp + assert dpnp.allclose(b_dp, b_dp_orig) + + assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-5, atol=1e-5) + + def test_strided(self): + n, B = 4, 6 + a_np = self._make_nonsingular_nd_np( + (B, n, n), dpnp.default_float_type(), "F" + ) + a_dp = dpnp.array(a_np, order="F") + + a_stride = a_dp[::2] + rhs_full = ( + dpnp.arange(B * n * 3, dtype=dpnp.default_float_type()).reshape( + B, n, 3, order="F" + ) + + 1.0 + ) + b_dp = rhs_full[::2, :, ::-1] + + lu, piv = dpnp.linalg.lu_factor(a_stride, check_finite=False) + x = dpnp.linalg.lu_solve( + (lu, piv), b_dp, overwrite_b=False, check_finite=False + ) + + assert dpnp.allclose(a_stride @ x, b_dp, rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize( + "dtype_a", get_all_dtypes(no_bool=True, no_none=True) + ) + @pytest.mark.parametrize( + "dtype_b", get_all_dtypes(no_bool=True, no_none=True) + ) + @pytest.mark.parametrize("b_shape", [(4, 2), (1, 4, 2), (2, 4, 2)]) + def test_diff_type(self, dtype_a, dtype_b, b_shape): + B, n, k = 2, 4, 2 + a_np = self._make_nonsingular_nd_np((B, n, n), dtype_a, "F") + a_dp = dpnp.array(a_np, order="F") + + b_np = generate_random_numpy_array(b_shape, dtype_b, "F") + b_dp = dpnp.array(b_np, order="F") + + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False) + + exp_shape = (B, n, k) + assert x.shape == exp_shape + + b_exp = dpnp.broadcast_to(b_dp, exp_shape) + assert dpnp.allclose( + a_dp @ x, b_exp.astype(x.dtype, copy=False), rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.parametrize( + "a_shape, b_shape", + [ + ((0, 3, 3), (0, 3)), + ((2, 0, 0), (2, 0)), + ((0, 0, 0), (0, 0)), + ], + ) + def test_empty_inputs(self, a_shape, b_shape): + a = dpnp.empty(a_shape, dtype=dpnp.default_float_type(), order="F") + b = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F") + + lu, piv = dpnp.linalg.lu_factor(a, check_finite=False) + x = dpnp.linalg.lu_solve((lu, piv), b, check_finite=False) + + assert x.shape == b_shape + + def test_check_finite_raises(self): + B, n = 2, 3 + a_np = self._make_nonsingular_nd_np( + (B, n, n), dpnp.default_float_type(), "F" + ) + a_dp = dpnp.array(a_np, order="F") + lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + + b_bad = dpnp.ones((B, n), dtype=dpnp.default_float_type(), order="F") + b_bad[1, 0] = dpnp.nan + assert_raises( + ValueError, + dpnp.linalg.lu_solve, + (lu, piv), + b_bad, + check_finite=True, + ) + + @pytest.mark.parametrize( + "a_shape, b_shape", + [ + ((2, 4, 4), (2,)), + ((2, 4, 4), (2, 4)), + ((2, 4, 4), (4, 4, 2)), + ((2, 4, 4), (2, 3, 4, 2)), + ((2, 3, 4, 4), (3, 4)), + ((2, 3, 4, 4), (2, 4)), + ((2, 3, 4, 4), (2, 3, 5, 2)), + ], + ) + def test_invalid_shapes(self, a_shape, b_shape): + dtype = dpnp.default_float_type() + a = dpnp.array( + self._make_nonsingular_nd_np(a_shape, dtype, "F"), order="F" + ) + b = dpnp.array( + generate_random_numpy_array(b_shape, dtype, "F"), order="F" + ) + + lu, piv = dpnp.linalg.lu_factor(a, check_finite=False) + with pytest.raises(ValueError): + dpnp.linalg.lu_solve((lu, piv), b, check_finite=False) + + class TestMatrixPower: @pytest.mark.parametrize("dtype", get_all_dtypes()) @pytest.mark.parametrize( From a13cd88a6b778ca233fb35469a0f2fc1bc5ca40b Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 9 Oct 2025 10:18:19 -0700 Subject: [PATCH 05/18] Compute strides explicitly --- dpnp/linalg/dpnp_utils_linalg.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 16dc78e7f68a..a40831623e3e 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -586,12 +586,9 @@ def _batched_lu_solve(lu, piv, b, res_type, trans=0): _manager.add_event_pair(ht_ev, b_copy_ev) dep_evs = [lu_copy_ev, b_copy_ev] - lu_stride = lu_h.strides[-1] - piv_stride = piv.strides[0] - b_stride = b_h.strides[-1] - - if not isinstance(trans, int): - raise TypeError("`trans` must be an integer") + lu_stride = n * n + piv_stride = n + b_stride = n * nrhs trans_mkl = _map_trans_to_mkl(trans) From b68d80bfece6fa8d4bcec203cd7b0da8f1daf6f5 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 9 Oct 2025 10:21:56 -0700 Subject: [PATCH 06/18] Update test_lu_solve in test_usm_type/sycl_queue.py --- dpnp/tests/test_sycl_queue.py | 13 +++++++++---- dpnp/tests/test_usm_type.py | 13 +++++++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index f4299325762a..32a9f214bca6 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -1612,11 +1612,16 @@ def test_lu_factor(self, data, device): assert_sycl_queue_equal(param_queue, a.sycl_queue) @pytest.mark.parametrize( - "b_data", - [[1.0, 2.0], numpy.empty((2, 0))], + "a_data, b_data", + [ + ([[1.0, 2.0], [3.0, 5.0]], [1.0, 2.0]), + ([[1.0, 2.0], [3.0, 5.0]], numpy.empty((2, 0))), + ([[[1.0, 2.0], [3.0, 5.0]]], [1.0, 2.0]), + ([[[1.0, 2.0], [3.0, 5.0]]], numpy.empty((2, 0, 2))), + ], ) - def test_lu_solve(self, b_data, device): - a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device) + def test_lu_solve(self, a_data, b_data, device): + a = dpnp.array(a_data, device=device) lu, piv = dpnp.linalg.lu_factor(a) b = dpnp.array(b_data, device=device) diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index c17526649ab3..573761394be7 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -1489,11 +1489,16 @@ def test_lu_factor(self, data, usm_type): @pytest.mark.parametrize("usm_type_rhs", list_of_usm_types) @pytest.mark.parametrize( - "b_data", - [[1.0, 2.0], numpy.empty((2, 0))], + "a_data, b_data", + [ + ([[1.0, 2.0], [3.0, 5.0]], [1.0, 2.0]), + ([[1.0, 2.0], [3.0, 5.0]], numpy.empty((2, 0))), + ([[[1.0, 2.0], [3.0, 5.0]]], [1.0, 2.0]), + ([[[1.0, 2.0], [3.0, 5.0]]], numpy.empty((2, 0, 2))), + ], ) - def test_lu_solve(self, b_data, usm_type, usm_type_rhs): - a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], usm_type=usm_type) + def test_lu_solve(self, a_data, b_data, usm_type, usm_type_rhs): + a = dpnp.array(a_data, usm_type=usm_type) lu, piv = dpnp.linalg.lu_factor(a) b = dpnp.array(b_data, usm_type=usm_type_rhs) From f6d77fe8d87052813a36a3daba4277e765baf780 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 9 Oct 2025 11:33:25 -0700 Subject: [PATCH 07/18] Add square-matrix assertion in lu_solve --- dpnp/linalg/dpnp_iface_linalg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index f73443229c49..f28ff0a017b2 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -1037,6 +1037,7 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): (lu, piv) = lu_and_piv dpnp.check_supported_arrays_type(lu, piv, b) assert_stacked_2d(lu) + assert_stacked_square(lu) return dpnp_lu_solve( lu, From eb8c58a732763e26735c20fb45bc9b196bc001ba Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 9 Oct 2025 11:34:10 -0700 Subject: [PATCH 08/18] Update test_empty_shapes for lu_solve() --- dpnp/tests/test_linalg.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index 440509ddb6f3..aa8ac0669225 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2320,19 +2320,15 @@ def test_broadcast_rhs(self, b_shape): assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-5, atol=1e-5) - @pytest.mark.parametrize("shape", [(0, 0), (0, 5), (5, 5)]) - @pytest.mark.parametrize("rhs_cols", [None, 0, 3]) - def test_empty_shapes(self, shape, rhs_cols): - a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F") - if min(shape) > 0: - for i in range(min(shape)): + @pytest.mark.parametrize("a_shape", [(0, 0), (5, 5)]) + @pytest.mark.parametrize("b_shape", [(0,), (0, 0), (0, 5)]) + def test_empty_shapes(self, a_shape, b_shape): + a_dp = dpnp.empty(a_shape, dtype=dpnp.default_float_type(), order="F") + n = a_shape[0] + + if n > 0: + for i in range(n): a_dp[i, i] = a_dp.dtype.type(1.0) - - n = shape[0] - if rhs_cols is None: - b_shape = (n,) - else: - b_shape = (n, rhs_cols) b_dp = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F") lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) @@ -2537,7 +2533,7 @@ def test_diff_type(self, dtype_a, dtype_b, b_shape): ((0, 0, 0), (0, 0)), ], ) - def test_empty_inputs(self, a_shape, b_shape): + def test_empty_shapes(self, a_shape, b_shape): a = dpnp.empty(a_shape, dtype=dpnp.default_float_type(), order="F") b = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F") From 959f5f85ebf72e57010bdb27b1386eecda44eb09 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 9 Oct 2025 12:16:58 -0700 Subject: [PATCH 09/18] Move changes to new location(scipy folder) --- dpnp/linalg/dpnp_utils_linalg.py | 45 -------- dpnp/scipy/linalg/_decomp_lu.py | 6 +- dpnp/scipy/linalg/_utils.py | 177 +++++++++++++++++++++++++++---- dpnp/tests/test_linalg.py | 32 +++--- 4 files changed, 177 insertions(+), 83 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 0654ef524c50..ee7650d4600d 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -106,37 +106,6 @@ class SVDResult(NamedTuple): } -def _align_lu_solve_broadcast(lu, b): - """Align LU and RHS batch dimensions with SciPy-like rules.""" - lu_shape = lu.shape - b_shape = b.shape - - if b.ndim < 2: - if lu_shape[-2] != b_shape[0]: - raise ValueError( - f"Shapes of lu {lu_shape} and b {b_shape} are incompatible" - ) - b = dpnp.broadcast_to(b, lu_shape[:-1]) - return lu, b - - if lu_shape[-2] != b_shape[-2]: - raise ValueError( - f"Shapes of lu {lu_shape} and b {b_shape} are incompatible" - ) - - # Use dpnp.broadcast_shapes() to align the resulting batch shapes - batch = dpnp.broadcast_shapes(lu_shape[:-2], b_shape[:-2]) - lu_bshape = batch + lu_shape[-2:] - b_bshape = batch + b_shape[-2:] - - if lu_shape != lu_bshape: - lu = dpnp.broadcast_to(lu, lu_bshape) - if b_shape != b_bshape: - b = dpnp.broadcast_to(b, b_bshape) - - return lu, b - - def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type): """ _batched_eigh(a, UPLO, eigen_mode, w_type, v_type) @@ -987,20 +956,6 @@ def _is_empty_2d(arr): return arr.size == 0 and numpy.prod(arr.shape[-2:]) == 0 -def _map_trans_to_mkl(trans): - """Map SciPy-style trans code (0,1,2) to oneMKL transpose enum.""" - if not isinstance(trans, int): - raise TypeError("`trans` must be an integer") - - if trans == 0: - return li.Transpose.N - if trans == 1: - return li.Transpose.T - if trans == 2: - return li.Transpose.C - raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)") - - def _lu_factor(a, res_type): """ Compute pivoted LU decomposition. diff --git a/dpnp/scipy/linalg/_decomp_lu.py b/dpnp/scipy/linalg/_decomp_lu.py index 50a824a822cf..d149c16edfcd 100644 --- a/dpnp/scipy/linalg/_decomp_lu.py +++ b/dpnp/scipy/linalg/_decomp_lu.py @@ -39,7 +39,10 @@ import dpnp -from dpnp.linalg.dpnp_utils_linalg import assert_stacked_2d +from dpnp.linalg.dpnp_utils_linalg import ( + assert_stacked_2d, + assert_stacked_square, +) from ._utils import ( dpnp_lu_factor, @@ -184,6 +187,7 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): (lu, piv) = lu_and_piv dpnp.check_supported_arrays_type(lu, piv, b) assert_stacked_2d(lu) + assert_stacked_square(lu) return dpnp_lu_solve( lu, diff --git a/dpnp/scipy/linalg/_utils.py b/dpnp/scipy/linalg/_utils.py index 377c41ac35c4..609bebac4610 100644 --- a/dpnp/scipy/linalg/_utils.py +++ b/dpnp/scipy/linalg/_utils.py @@ -55,6 +55,37 @@ ] +def _align_lu_solve_broadcast(lu, b): + """Align LU and RHS batch dimensions with SciPy-like rules.""" + lu_shape = lu.shape + b_shape = b.shape + + if b.ndim < 2: + if lu_shape[-2] != b_shape[0]: + raise ValueError( + f"Shapes of lu {lu_shape} and b {b_shape} are incompatible" + ) + b = dpnp.broadcast_to(b, lu_shape[:-1]) + return lu, b + + if lu_shape[-2] != b_shape[-2]: + raise ValueError( + f"Shapes of lu {lu_shape} and b {b_shape} are incompatible" + ) + + # Use dpnp.broadcast_shapes() to align the resulting batch shapes + batch = dpnp.broadcast_shapes(lu_shape[:-2], b_shape[:-2]) + lu_bshape = batch + lu_shape[-2:] + b_bshape = batch + b_shape[-2:] + + if lu_shape != lu_bshape: + lu = dpnp.broadcast_to(lu, lu_bshape) + if b_shape != b_bshape: + b = dpnp.broadcast_to(b, b_bshape) + + return lu, b + + def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals """SciPy-compatible LU factorization for batched inputs.""" @@ -180,6 +211,106 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals return (a_h, ipiv_h) +def _batched_lu_solve(lu, piv, b, res_type, trans=0): + """Solve a batched equation system (SciPy-compatible behavior).""" + res_usm_type, exec_q = get_usm_allocations([lu, piv, b]) + + if b.size == 0: + return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type) + + b_ndim = b.ndim + + lu, b = _align_lu_solve_broadcast(lu, b) + + n = lu.shape[-1] + nrhs = b.shape[-1] if b_ndim > 1 else 1 + + # get 3d input arrays by reshape + if lu.ndim > 3: + lu = dpnp.reshape(lu, (-1, n, n)) + # get 2d pivot arrays by reshape + if piv.ndim > 2: + piv = dpnp.reshape(piv, (-1, n)) + batch_size = lu.shape[0] + + # Move batch axis to the end (n, n, batch) in Fortran order: + # required by getrs_batch + # and ensures each a[..., i] is F-contiguous for getrs_batch + lu = dpnp.moveaxis(lu, 0, -1) + + b_orig_shape = b.shape + if b.ndim > 2: + b = dpnp.reshape(b, (-1, n, nrhs)) + + # Move batch axis to the end (n, nrhs, batch) in Fortran order: + # required by getrs_batch + # and ensures each b[..., i] is F-contiguous for getrs_batch + b = dpnp.moveaxis(b, 0, -1) + + lu_usm_arr = dpnp.get_usm_ndarray(lu) + b_usm_arr = dpnp.get_usm_ndarray(b) + + # dpnp.linalg.lu_factor() returns 0-based pivots to match SciPy, + # convert to 1-based for oneMKL getrs_batch + piv_h = piv + 1 + + _manager = dpu.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + + # oneMKL LAPACK getrs overwrites `lu`. + lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type) + + # use DPCTL tensor function to fill the сopy of the input array + # from the input array + ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=lu_usm_arr, + dst=lu_h.get_array(), + sycl_queue=lu.sycl_queue, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, lu_copy_ev) + + b_h = dpnp.empty_like(b, order="F", dtype=res_type, usm_type=res_usm_type) + ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=b_usm_arr, + dst=b_h.get_array(), + sycl_queue=b.sycl_queue, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, b_copy_ev) + dep_evs = [lu_copy_ev, b_copy_ev] + + lu_stride = n * n + piv_stride = n + b_stride = n * nrhs + + trans_mkl = _map_trans_to_mkl(trans) + + # Call the LAPACK extension function _getrs_batch + # to solve the system of linear equations with an LU-factored + # coefficient square matrix, with multiple right-hand sides. + ht_ev, getrs_batch_ev = li._getrs_batch( + exec_q, + lu_h.get_array(), + piv_h.get_array(), + b_h.get_array(), + trans_mkl, + n, + nrhs, + lu_stride, + piv_stride, + b_stride, + batch_size, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, getrs_batch_ev) + + # Restore original shape: move batch axis back and reshape + b_h = dpnp.moveaxis(b_h, -1, 0).reshape(b_orig_shape) + + return b_h + + def _is_copy_required(a, res_type): """ Determine if `a` needs to be copied before LU decomposition. @@ -197,6 +328,20 @@ def _is_copy_required(a, res_type): return False +def _map_trans_to_mkl(trans): + """Map SciPy-style trans code (0,1,2) to oneMKL transpose enum.""" + if not isinstance(trans, int): + raise TypeError("`trans` must be an integer") + + if trans == 0: + return li.Transpose.N + if trans == 1: + return li.Transpose.T + if trans == 2: + return li.Transpose.C + raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)") + + def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): """ dpnp_lu_factor(a, overwrite_a=False, check_finite=True) @@ -307,18 +452,9 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): res_type = _common_type(lu, b) - # TODO: add broadcasting - if lu.shape[0] != b.shape[0]: - raise ValueError( - f"Shapes of lu {lu.shape} and b {b.shape} are incompatible" - ) - if b.size == 0: return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type) - if lu.ndim > 2: - raise NotImplementedError("Batched matrices are not supported") - if check_finite: if not dpnp.isfinite(lu).all(): raise ValueError( @@ -331,6 +467,16 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): "Right-hand side array must not contain infs or NaNs" ) + if lu.ndim > 2: + # SciPy always copies each 2D slice, + # so `overwrite_b` is ignored here + return _batched_lu_solve(lu, piv, b, trans=trans, res_type=res_type) + + if lu.shape[0] != b.shape[0]: + raise ValueError( + f"Shapes of lu {lu.shape} and b {b.shape} are incompatible" + ) + lu_usm_arr = dpnp.get_usm_ndarray(lu) b_usm_arr = dpnp.get_usm_ndarray(b) @@ -377,18 +523,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): b_h = b dep_evs = [lu_copy_ev] - if not isinstance(trans, int): - raise TypeError("`trans` must be an integer") - - # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums - if trans == 0: - trans_mkl = li.Transpose.N - elif trans == 1: - trans_mkl = li.Transpose.T - elif trans == 2: - trans_mkl = li.Transpose.C - else: - raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)") + trans_mkl = _map_trans_to_mkl(trans) # Call the LAPACK extension function _getrs # to solve the system of linear equations with an LU-factored diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index 8a6055d1915b..55031706ea37 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2421,8 +2421,8 @@ def test_lu_solve_batched(self, a_shape, b_shape, dtype, order): b_np = generate_random_numpy_array(b_shape, dtype, order) b_dp = dpnp.array(b_np, order=order) - lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) - x = dpnp.linalg.lu_solve( + lu, piv = dpnp.scipy.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.scipy.linalg.lu_solve( (lu, piv), b_dp, overwrite_b=True, check_finite=False ) @@ -2449,10 +2449,10 @@ def test_trans(self, trans, order, dtype): generate_random_numpy_array(b_shape, dtype, order), order=order ) - lu, piv = dpnp.linalg.lu_factor( + lu, piv = dpnp.scipy.linalg.lu_factor( a_dp, overwrite_a=False, check_finite=False ) - x = dpnp.linalg.lu_solve( + x = dpnp.scipy.linalg.lu_solve( (lu, piv), b_dp, trans=trans, overwrite_b=False, check_finite=False ) @@ -2471,7 +2471,7 @@ def test_overwrite(self, dtype, order): a_np = self._make_nonsingular_nd_np((2, 4, 4), dtype, order) a_dp = dpnp.array(a_np, order=order) - lu, piv = dpnp.linalg.lu_factor( + lu, piv = dpnp.scipy.linalg.lu_factor( a_dp, overwrite_a=False, check_finite=False ) @@ -2479,7 +2479,7 @@ def test_overwrite(self, dtype, order): generate_random_numpy_array((2, 4, 2), dtype, "F"), order="F" ) b_dp_orig = b_dp.copy() - x = dpnp.linalg.lu_solve( + x = dpnp.scipy.linalg.lu_solve( (lu, piv), b_dp, overwrite_b=True, check_finite=False ) @@ -2504,8 +2504,8 @@ def test_strided(self): ) b_dp = rhs_full[::2, :, ::-1] - lu, piv = dpnp.linalg.lu_factor(a_stride, check_finite=False) - x = dpnp.linalg.lu_solve( + lu, piv = dpnp.scipy.linalg.lu_factor(a_stride, check_finite=False) + x = dpnp.scipy.linalg.lu_solve( (lu, piv), b_dp, overwrite_b=False, check_finite=False ) @@ -2526,8 +2526,8 @@ def test_diff_type(self, dtype_a, dtype_b, b_shape): b_np = generate_random_numpy_array(b_shape, dtype_b, "F") b_dp = dpnp.array(b_np, order="F") - lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) - x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False) + lu, piv = dpnp.scipy.linalg.lu_factor(a_dp, check_finite=False) + x = dpnp.scipy.linalg.lu_solve((lu, piv), b_dp, check_finite=False) exp_shape = (B, n, k) assert x.shape == exp_shape @@ -2549,8 +2549,8 @@ def test_empty_shapes(self, a_shape, b_shape): a = dpnp.empty(a_shape, dtype=dpnp.default_float_type(), order="F") b = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F") - lu, piv = dpnp.linalg.lu_factor(a, check_finite=False) - x = dpnp.linalg.lu_solve((lu, piv), b, check_finite=False) + lu, piv = dpnp.scipy.linalg.lu_factor(a, check_finite=False) + x = dpnp.scipy.linalg.lu_solve((lu, piv), b, check_finite=False) assert x.shape == b_shape @@ -2560,13 +2560,13 @@ def test_check_finite_raises(self): (B, n, n), dpnp.default_float_type(), "F" ) a_dp = dpnp.array(a_np, order="F") - lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False) + lu, piv = dpnp.scipy.linalg.lu_factor(a_dp, check_finite=False) b_bad = dpnp.ones((B, n), dtype=dpnp.default_float_type(), order="F") b_bad[1, 0] = dpnp.nan assert_raises( ValueError, - dpnp.linalg.lu_solve, + dpnp.scipy.linalg.lu_solve, (lu, piv), b_bad, check_finite=True, @@ -2593,9 +2593,9 @@ def test_invalid_shapes(self, a_shape, b_shape): generate_random_numpy_array(b_shape, dtype, "F"), order="F" ) - lu, piv = dpnp.linalg.lu_factor(a, check_finite=False) + lu, piv = dpnp.scipy.linalg.lu_factor(a, check_finite=False) with pytest.raises(ValueError): - dpnp.linalg.lu_solve((lu, piv), b, check_finite=False) + dpnp.scipy.linalg.lu_solve((lu, piv), b, check_finite=False) class TestMatrixPower: From e760aa20829708cbe7ce5b8625a19c0d5e86f103 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 9 Oct 2025 12:21:03 -0700 Subject: [PATCH 10/18] Update lu_solve docs --- dpnp/scipy/linalg/_decomp_lu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dpnp/scipy/linalg/_decomp_lu.py b/dpnp/scipy/linalg/_decomp_lu.py index d149c16edfcd..64e18994d772 100644 --- a/dpnp/scipy/linalg/_decomp_lu.py +++ b/dpnp/scipy/linalg/_decomp_lu.py @@ -132,7 +132,7 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): Parameters ---------- lu, piv : {tuple of dpnp.ndarrays or usm_ndarrays} - LU factorization of matrix `a` (M, M) together with pivot indices. + LU factorization of matrix `a` (..., M, M) together with pivot indices. b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray} Right-hand side trans : {0, 1, 2} , optional @@ -160,7 +160,7 @@ def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): Returns ------- - x : {(M,), (M, K)} dpnp.ndarray + x : {(M,), (..., M, K)} dpnp.ndarray Solution to the system Warning From 38689a351df036ed668fb3e940eafcd0ad50301f Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 9 Oct 2025 12:23:59 -0700 Subject: [PATCH 11/18] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0bab2d17db60..eba8850b18d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added * Added the docstrings to `dpnp.linalg.LinAlgError` exception [#2613](https://github.com/IntelPython/dpnp/pull/2613) +* Added implementation of `dpnp.linalg.lu_solve` for batch inputs (SciPy-compatible) [#2618](https://github.com/IntelPython/dpnp/pull/2618) ### Changed From 240f97e2bb323f5ca6bf8d09021ee6d5e72d09e0 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 10 Oct 2025 04:41:28 -0700 Subject: [PATCH 12/18] Apply text-related comments --- CHANGELOG.md | 2 +- dpnp/backend/extensions/lapack/getrs_batch.cpp | 3 +++ dpnp/scipy/linalg/_utils.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eba8850b18d0..d7e7d41c3308 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added * Added the docstrings to `dpnp.linalg.LinAlgError` exception [#2613](https://github.com/IntelPython/dpnp/pull/2613) -* Added implementation of `dpnp.linalg.lu_solve` for batch inputs (SciPy-compatible) [#2618](https://github.com/IntelPython/dpnp/pull/2618) +* Added implementation of `dpnp.linalg.lu_solve` for batch inputs (SciPy-compatible) [#2619](https://github.com/IntelPython/dpnp/pull/2619) ### Changed diff --git a/dpnp/backend/extensions/lapack/getrs_batch.cpp b/dpnp/backend/extensions/lapack/getrs_batch.cpp index b2905e82690c..e3df26325585 100644 --- a/dpnp/backend/extensions/lapack/getrs_batch.cpp +++ b/dpnp/backend/extensions/lapack/getrs_batch.cpp @@ -9,6 +9,9 @@ // - Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation // and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE diff --git a/dpnp/scipy/linalg/_utils.py b/dpnp/scipy/linalg/_utils.py index 609bebac4610..fb03d032fe00 100644 --- a/dpnp/scipy/linalg/_utils.py +++ b/dpnp/scipy/linalg/_utils.py @@ -487,7 +487,7 @@ def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): _manager = dpu.SequentialOrderManager[exec_q] dep_evs = _manager.submitted_events - # oneMKL LAPACK getrs overwrites `lu`. + # oneMKL LAPACK getrs_batch overwrites `lu`. lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type) # use DPCTL tensor function to fill the сopy of the input array From 01078bf0a6216c12f170b4368fbb3fc92d77ebcd Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 10 Oct 2025 06:25:54 -0700 Subject: [PATCH 13/18] Const correctness for getrs and getrs_batch params --- dpnp/backend/extensions/lapack/getrs.cpp | 14 +++++----- dpnp/backend/extensions/lapack/getrs.hpp | 16 +++++------ .../backend/extensions/lapack/getrs_batch.cpp | 28 +++++++++---------- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/dpnp/backend/extensions/lapack/getrs.cpp b/dpnp/backend/extensions/lapack/getrs.cpp index 4112a7c95ae7..81c1326e4ebf 100644 --- a/dpnp/backend/extensions/lapack/getrs.cpp +++ b/dpnp/backend/extensions/lapack/getrs.cpp @@ -48,14 +48,14 @@ namespace type_utils = dpctl::tensor::type_utils; using ext::common::init_dispatch_vector; typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue &, - oneapi::mkl::transpose, + const oneapi::mkl::transpose, const std::int64_t, const std::int64_t, char *, - std::int64_t, - std::int64_t *, + const std::int64_t, + const std::int64_t *, char *, - std::int64_t, + const std::int64_t, std::vector &, const std::vector &); @@ -67,10 +67,10 @@ static sycl::event getrs_impl(sycl::queue &exec_q, const std::int64_t n, const std::int64_t nrhs, char *in_a, - std::int64_t lda, - std::int64_t *ipiv, + const std::int64_t lda, + const std::int64_t *ipiv, char *in_b, - std::int64_t ldb, + const std::int64_t ldb, std::vector &host_task_events, const std::vector &depends) { diff --git a/dpnp/backend/extensions/lapack/getrs.hpp b/dpnp/backend/extensions/lapack/getrs.hpp index 9babe85ae90d..8e0c5bcdd5f0 100644 --- a/dpnp/backend/extensions/lapack/getrs.hpp +++ b/dpnp/backend/extensions/lapack/getrs.hpp @@ -37,7 +37,7 @@ extern std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, const dpctl::tensor::usm_ndarray &b_array, - oneapi::mkl::transpose trans, + const oneapi::mkl::transpose trans, const std::vector &depends = {}); extern std::pair @@ -45,13 +45,13 @@ extern std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, const dpctl::tensor::usm_ndarray &b_array, - oneapi::mkl::transpose trans, - std::int64_t n, - std::int64_t nrhs, - std::int64_t stride_a, - std::int64_t stride_ipiv, - std::int64_t stride_b, - std::int64_t batch_size, + const oneapi::mkl::transpose trans, + const std::int64_t n, + const std::int64_t nrhs, + const std::int64_t stride_a, + const std::int64_t stride_ipiv, + const std::int64_t stride_b, + const std::int64_t batch_size, const std::vector &depends = {}); extern void init_getrs_dispatch_vector(void); diff --git a/dpnp/backend/extensions/lapack/getrs_batch.cpp b/dpnp/backend/extensions/lapack/getrs_batch.cpp index e3df26325585..77c5639a6b6f 100644 --- a/dpnp/backend/extensions/lapack/getrs_batch.cpp +++ b/dpnp/backend/extensions/lapack/getrs_batch.cpp @@ -56,14 +56,14 @@ typedef sycl::event (*getrs_batch_impl_fn_ptr_t)( const std::int64_t, // n const std::int64_t, // nrhs char *, // a - std::int64_t, // lda - std::int64_t, // stride_a - std::int64_t *, // ipiv - std::int64_t, // stride_ipiv + const std::int64_t, // lda + const std::int64_t, // stride_a + const std::int64_t *, // ipiv + const std::int64_t, // stride_ipiv char *, // b - std::int64_t, // ldb - std::int64_t, // stride_b - std::int64_t, // batch_size + const std::int64_t, // ldb + const std::int64_t, // stride_b + const std::int64_t, // batch_size std::vector &, const std::vector &); @@ -75,14 +75,14 @@ static sycl::event getrs_batch_impl(sycl::queue &exec_q, const std::int64_t n, const std::int64_t nrhs, char *in_a, - std::int64_t lda, - std::int64_t stride_a, - std::int64_t *ipiv, - std::int64_t stride_ipiv, + const std::int64_t lda, + const std::int64_t stride_a, + const std::int64_t *ipiv, + const std::int64_t stride_ipiv, char *in_b, - std::int64_t ldb, - std::int64_t stride_b, - std::int64_t batch_size, + const std::int64_t ldb, + const std::int64_t stride_b, + const std::int64_t batch_size, std::vector &host_task_events, const std::vector &depends) { From 6b3f331dcd67c604becbfdcb1d515e124c71dab6 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 10 Oct 2025 06:40:06 -0700 Subject: [PATCH 14/18] Apply remarks for _utils.py --- dpnp/scipy/linalg/_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dpnp/scipy/linalg/_utils.py b/dpnp/scipy/linalg/_utils.py index fb03d032fe00..5aa1588da5cb 100644 --- a/dpnp/scipy/linalg/_utils.py +++ b/dpnp/scipy/linalg/_utils.py @@ -215,9 +215,6 @@ def _batched_lu_solve(lu, piv, b, res_type, trans=0): """Solve a batched equation system (SciPy-compatible behavior).""" res_usm_type, exec_q = get_usm_allocations([lu, piv, b]) - if b.size == 0: - return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type) - b_ndim = b.ndim lu, b = _align_lu_solve_broadcast(lu, b) @@ -257,7 +254,7 @@ def _batched_lu_solve(lu, piv, b, res_type, trans=0): _manager = dpu.SequentialOrderManager[exec_q] dep_evs = _manager.submitted_events - # oneMKL LAPACK getrs overwrites `lu`. + # oneMKL LAPACK getrs_batch overwrites `lu` lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type) # use DPCTL tensor function to fill the сopy of the input array @@ -270,6 +267,8 @@ def _batched_lu_solve(lu, piv, b, res_type, trans=0): ) _manager.add_event_pair(ht_ev, lu_copy_ev) + # oneMKL LAPACK getrs_batch overwrites `b` and assumes fortran-like array + # as input b_h = dpnp.empty_like(b, order="F", dtype=res_type, usm_type=res_usm_type) ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=b_usm_arr, From 83a8c851d7b683b67a11a8881971b03dcf61fccb Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 10 Oct 2025 06:56:42 -0700 Subject: [PATCH 15/18] Apply remarks for getrf_batch.cpp --- .../backend/extensions/lapack/getrs_batch.cpp | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/dpnp/backend/extensions/lapack/getrs_batch.cpp b/dpnp/backend/extensions/lapack/getrs_batch.cpp index 77c5639a6b6f..a526efd6c079 100644 --- a/dpnp/backend/extensions/lapack/getrs_batch.cpp +++ b/dpnp/backend/extensions/lapack/getrs_batch.cpp @@ -33,10 +33,12 @@ #include #include +// utils extension header +#include "ext/common.hpp" + // dpctl tensor headers #include "utils/memory_overlap.hpp" #include "utils/sycl_alloc_utils.hpp" -#include "utils/type_dispatch.hpp" #include "utils/type_utils.hpp" #include "getrs.hpp" @@ -48,7 +50,8 @@ namespace dpnp::extensions::lapack namespace mkl_lapack = oneapi::mkl::lapack; namespace py = pybind11; namespace type_utils = dpctl::tensor::type_utils; -namespace td_ns = dpctl::tensor::type_dispatch; + +using ext::common::init_dispatch_vector; typedef sycl::event (*getrs_batch_impl_fn_ptr_t)( sycl::queue &, @@ -67,7 +70,8 @@ typedef sycl::event (*getrs_batch_impl_fn_ptr_t)( std::vector &, const std::vector &); -static getrs_batch_impl_fn_ptr_t getrs_batch_dispatch_vector[td_ns::num_types]; +static getrs_batch_impl_fn_ptr_t + getrs_batch_dispatch_vector[dpctl_td_ns::num_types]; template static sycl::event getrs_batch_impl(sycl::queue &exec_q, @@ -225,7 +229,7 @@ std::pair std::to_string(b_array_nd) + ", but an array with ndim >= 2 is expected"); } - if (ipiv_array_nd < 1) { + if (ipiv_array_nd < 2) { throw py::value_error("The array of pivot indices has ndim=" + std::to_string(ipiv_array_nd) + ", but an array with ndim >= 2 is expected"); @@ -280,7 +284,7 @@ std::pair "must be contiguous"); } - auto array_types = td_ns::usm_ndarray_types(); + auto array_types = dpctl_td_ns::usm_ndarray_types(); int a_array_type_id = array_types.typenum_to_lookup_id(a_array.get_typenum()); int b_array_type_id = @@ -299,11 +303,11 @@ std::pair "of the input matrix"); } - auto ipiv_types = td_ns::usm_ndarray_types(); + auto ipiv_types = dpctl_td_ns::usm_ndarray_types(); int ipiv_array_type_id = ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum()); - if (ipiv_array_type_id != static_cast(td_ns::typenum_t::INT64)) { + if (ipiv_array_type_id != static_cast(dpctl_td_ns::typenum_t::INT64)) { throw py::value_error("The type of 'ipiv_array' must be int64"); } @@ -343,9 +347,7 @@ struct GetrsBatchContigFactory void init_getrs_batch_dispatch_vector(void) { - td_ns::DispatchVectorBuilder - contig; - contig.populate_dispatch_vector(getrs_batch_dispatch_vector); + init_dispatch_vector( + getrs_batch_dispatch_vector); } } // namespace dpnp::extensions::lapack From cddc32107c1c65e9820e6072825a94e38e25a65a Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 10 Oct 2025 08:12:13 -0700 Subject: [PATCH 16/18] Fix ipiv_contig check in getrs --- dpnp/backend/extensions/lapack/getrs.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dpnp/backend/extensions/lapack/getrs.cpp b/dpnp/backend/extensions/lapack/getrs.cpp index 1802a92b1bc4..2186620ffb99 100644 --- a/dpnp/backend/extensions/lapack/getrs.cpp +++ b/dpnp/backend/extensions/lapack/getrs.cpp @@ -234,7 +234,7 @@ std::pair throw py::value_error("The right-hand sides array " "must be F-contiguous"); } - if (!is_ipiv_array_c_contig || !is_ipiv_array_f_contig) { + if (!is_ipiv_array_c_contig && !is_ipiv_array_f_contig) { throw py::value_error("The array of pivot indices " "must be contiguous"); } From 3e7ed555b505498e6a1aa2006022f059f89bc822 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 13 Oct 2025 09:09:17 -0700 Subject: [PATCH 17/18] Remove const from ipiv to match oneMath signature --- dpnp/backend/extensions/lapack/getrs.cpp | 4 ++-- dpnp/backend/extensions/lapack/getrs_batch.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dpnp/backend/extensions/lapack/getrs.cpp b/dpnp/backend/extensions/lapack/getrs.cpp index 2186620ffb99..8108afd97003 100644 --- a/dpnp/backend/extensions/lapack/getrs.cpp +++ b/dpnp/backend/extensions/lapack/getrs.cpp @@ -56,7 +56,7 @@ typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue &, const std::int64_t, char *, const std::int64_t, - const std::int64_t *, + std::int64_t *, char *, const std::int64_t, std::vector &, @@ -71,7 +71,7 @@ static sycl::event getrs_impl(sycl::queue &exec_q, const std::int64_t nrhs, char *in_a, const std::int64_t lda, - const std::int64_t *ipiv, + std::int64_t *ipiv, char *in_b, const std::int64_t ldb, std::vector &host_task_events, diff --git a/dpnp/backend/extensions/lapack/getrs_batch.cpp b/dpnp/backend/extensions/lapack/getrs_batch.cpp index a526efd6c079..9fc6ce1a5dfc 100644 --- a/dpnp/backend/extensions/lapack/getrs_batch.cpp +++ b/dpnp/backend/extensions/lapack/getrs_batch.cpp @@ -61,7 +61,7 @@ typedef sycl::event (*getrs_batch_impl_fn_ptr_t)( char *, // a const std::int64_t, // lda const std::int64_t, // stride_a - const std::int64_t *, // ipiv + std::int64_t *, // ipiv const std::int64_t, // stride_ipiv char *, // b const std::int64_t, // ldb @@ -81,7 +81,7 @@ static sycl::event getrs_batch_impl(sycl::queue &exec_q, char *in_a, const std::int64_t lda, const std::int64_t stride_a, - const std::int64_t *ipiv, + std::int64_t *ipiv, const std::int64_t stride_ipiv, char *in_b, const std::int64_t ldb, From 183293bc64068082d3086859293aecf329a2e61f Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 14 Oct 2025 03:00:28 -0700 Subject: [PATCH 18/18] Apply remarks --- dpnp/scipy/linalg/_utils.py | 8 ++++---- dpnp/tests/test_linalg.py | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/dpnp/scipy/linalg/_utils.py b/dpnp/scipy/linalg/_utils.py index 27a7ffa58bdd..eab46b8e92dc 100644 --- a/dpnp/scipy/linalg/_utils.py +++ b/dpnp/scipy/linalg/_utils.py @@ -218,12 +218,12 @@ def _batched_lu_solve(lu, piv, b, res_type, trans=0): """Solve a batched equation system (SciPy-compatible behavior).""" res_usm_type, exec_q = get_usm_allocations([lu, piv, b]) - b_ndim = b.ndim + b_ndim_orig = b.ndim lu, b = _align_lu_solve_broadcast(lu, b) n = lu.shape[-1] - nrhs = b.shape[-1] if b_ndim > 1 else 1 + nrhs = b.shape[-1] if b_ndim_orig > 1 else 1 # get 3d input arrays by reshape if lu.ndim > 3: @@ -235,11 +235,11 @@ def _batched_lu_solve(lu, piv, b, res_type, trans=0): # Move batch axis to the end (n, n, batch) in Fortran order: # required by getrs_batch - # and ensures each a[..., i] is F-contiguous for getrs_batch + # and ensures each lu[..., i] is F-contiguous for getrs_batch lu = dpnp.moveaxis(lu, 0, -1) b_orig_shape = b.shape - if b.ndim > 2: + if b.ndim > 3: b = dpnp.reshape(b, (-1, n, nrhs)) # Move batch axis to the end (n, nrhs, batch) in Fortran order: diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index 55031706ea37..c5f6d3629515 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2339,8 +2339,7 @@ def test_empty_shapes(self, a_shape, b_shape): n = a_shape[0] if n > 0: - for i in range(n): - a_dp[i, i] = a_dp.dtype.type(1.0) + dpnp.fill_diagonal(a_dp, a_dp.dtype.type(1.0)) b_dp = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F") lu, piv = dpnp.scipy.linalg.lu_factor(a_dp, check_finite=False)