Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
36d21df
Implement getrs_batch lapack extension
vlad-perevezentsev Oct 7, 2025
4270823
Add more validation checks in getrs_batch
vlad-perevezentsev Oct 9, 2025
14feed3
Implement _batched_lu_solve
vlad-perevezentsev Oct 9, 2025
5e21a02
Add TestLuSolveBatched
vlad-perevezentsev Oct 9, 2025
a13cd88
Compute strides explicitly
vlad-perevezentsev Oct 9, 2025
b68d80b
Update test_lu_solve in test_usm_type/sycl_queue.py
vlad-perevezentsev Oct 9, 2025
f6d77fe
Add square-matrix assertion in lu_solve
vlad-perevezentsev Oct 9, 2025
eb8c58a
Update test_empty_shapes for lu_solve()
vlad-perevezentsev Oct 9, 2025
7464a25
Merge master into impl_lu_solve_batch
vlad-perevezentsev Oct 9, 2025
959f5f8
Move changes to new location(scipy folder)
vlad-perevezentsev Oct 9, 2025
e760aa2
Update lu_solve docs
vlad-perevezentsev Oct 9, 2025
38689a3
Update changelog
vlad-perevezentsev Oct 9, 2025
240f97e
Apply text-related comments
vlad-perevezentsev Oct 10, 2025
01078bf
Const correctness for getrs and getrs_batch params
vlad-perevezentsev Oct 10, 2025
6b3f331
Apply remarks for _utils.py
vlad-perevezentsev Oct 10, 2025
83a8c85
Apply remarks for getrf_batch.cpp
vlad-perevezentsev Oct 10, 2025
4208454
Merge master into impl_lu_solve_batch
vlad-perevezentsev Oct 10, 2025
cddc321
Fix ipiv_contig check in getrs
vlad-perevezentsev Oct 10, 2025
8b7aecb
Merge master into impl_lu_solve_batch
vlad-perevezentsev Oct 10, 2025
3e7ed55
Remove const from ipiv to match oneMath signature
vlad-perevezentsev Oct 13, 2025
ee24137
Merge master into impl_lu_solve_batch
vlad-perevezentsev Oct 13, 2025
183293b
Apply remarks
vlad-perevezentsev Oct 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ This release changes the license from `BSD-2-Clause` to `BSD-3-Clause`.
### Added

* Added the docstrings to `dpnp.linalg.LinAlgError` exception [#2613](https://github.com/IntelPython/dpnp/pull/2613)
* Added implementation of `dpnp.linalg.lu_solve` for batch inputs (SciPy-compatible) [#2619](https://github.com/IntelPython/dpnp/pull/2619)

### Changed

Expand Down
1 change: 1 addition & 0 deletions dpnp/backend/extensions/lapack/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/getrs.cpp
${CMAKE_CURRENT_SOURCE_DIR}/getrs_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/heevd_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp
Expand Down
12 changes: 6 additions & 6 deletions dpnp/backend/extensions/lapack/getrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ namespace type_utils = dpctl::tensor::type_utils;
using ext::common::init_dispatch_vector;

typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue &,
oneapi::mkl::transpose,
const oneapi::mkl::transpose,
const std::int64_t,
const std::int64_t,
char *,
std::int64_t,
const std::int64_t,
std::int64_t *,
char *,
std::int64_t,
const std::int64_t,
std::vector<sycl::event> &,
const std::vector<sycl::event> &);

Expand All @@ -70,10 +70,10 @@ static sycl::event getrs_impl(sycl::queue &exec_q,
const std::int64_t n,
const std::int64_t nrhs,
char *in_a,
std::int64_t lda,
const std::int64_t lda,
std::int64_t *ipiv,
char *in_b,
std::int64_t ldb,
const std::int64_t ldb,
std::vector<sycl::event> &host_task_events,
const std::vector<sycl::event> &depends)
{
Expand Down Expand Up @@ -234,7 +234,7 @@ std::pair<sycl::event, sycl::event>
throw py::value_error("The right-hand sides array "
"must be F-contiguous");
}
if (!is_ipiv_array_c_contig || !is_ipiv_array_f_contig) {
if (!is_ipiv_array_c_contig && !is_ipiv_array_f_contig) {
throw py::value_error("The array of pivot indices "
"must be contiguous");
}
Expand Down
17 changes: 16 additions & 1 deletion dpnp/backend/extensions/lapack/getrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,23 @@ extern std::pair<sycl::event, sycl::event>
const dpctl::tensor::usm_ndarray &a_array,
const dpctl::tensor::usm_ndarray &ipiv_array,
const dpctl::tensor::usm_ndarray &b_array,
oneapi::mkl::transpose trans,
const oneapi::mkl::transpose trans,
const std::vector<sycl::event> &depends = {});

extern std::pair<sycl::event, sycl::event>
getrs_batch(sycl::queue &exec_q,
const dpctl::tensor::usm_ndarray &a_array,
const dpctl::tensor::usm_ndarray &ipiv_array,
const dpctl::tensor::usm_ndarray &b_array,
const oneapi::mkl::transpose trans,
const std::int64_t n,
const std::int64_t nrhs,
const std::int64_t stride_a,
const std::int64_t stride_ipiv,
const std::int64_t stride_b,
const std::int64_t batch_size,
const std::vector<sycl::event> &depends = {});

extern void init_getrs_dispatch_vector(void);
extern void init_getrs_batch_dispatch_vector(void);
} // namespace dpnp::extensions::lapack
Loading
Loading