-
Notifications
You must be signed in to change notification settings - Fork 23
Implement dpnp.linalg.lu_solve()
2D inputs
#2575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vlad-perevezentsev
wants to merge
16
commits into
master
Choose a base branch
from
impl_lu_solve_2D
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+455
−0
Open
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
4223c72
Pass trans_code to getrs in dpnp_solve()
vlad-perevezentsev 80ce50c
Remove TODO
vlad-perevezentsev af0ab7d
Implement of dpnp.linalg.lu_solve for 2D inputs
vlad-perevezentsev 17b11ae
Add dpnp.linalg.lu_solve to generated docs
vlad-perevezentsev b10a8d6
Add TestLuSolve to test_linalg.py
vlad-perevezentsev 2021f77
Add sycl_queue and usm_type tests
vlad-perevezentsev be2725a
Update doc/comment lines
vlad-perevezentsev 1e09cb7
Update dependency logic
vlad-perevezentsev 9345b7b
Add trans code handling
vlad-perevezentsev 687006f
Fix docs for lu:must be square
vlad-perevezentsev b1aed58
Merge master into impl_lu_solve_2D
vlad-perevezentsev 9aaff82
Update changelog
vlad-perevezentsev 23ad15d
Apply docs remarks
vlad-perevezentsev 82de136
Apply remarks
vlad-perevezentsev e586075
Add assert on USM data pointer to tests
vlad-perevezentsev 7d1fd0b
Update data inputs for test_usm_type
vlad-perevezentsev File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,6 +57,7 @@ | |
dpnp_inv, | ||
dpnp_lstsq, | ||
dpnp_lu_factor, | ||
dpnp_lu_solve, | ||
dpnp_matrix_power, | ||
dpnp_matrix_rank, | ||
dpnp_multi_dot, | ||
|
@@ -81,6 +82,7 @@ | |
"inv", | ||
"lstsq", | ||
"lu_factor", | ||
"lu_solve", | ||
"matmul", | ||
"matrix_norm", | ||
"matrix_power", | ||
|
@@ -966,6 +968,81 @@ def lu_factor(a, overwrite_a=False, check_finite=True): | |
return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite) | ||
|
||
|
||
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): | ||
""" | ||
Solve a linear system, :math:`a x = b`, given the LU factorization of `a`. | ||
|
||
For full documentation refer to :obj:`scipy.linalg.lu_solve`. | ||
|
||
Parameters | ||
---------- | ||
lu, piv : {tuple of dpnp.ndarrays or usm_ndarrays} | ||
LU factorization of matrix `a` (M, M) together with pivot indices. | ||
b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray} | ||
Right-hand side | ||
trans : {0, 1, 2} , optional | ||
Type of system to solve: | ||
|
||
===== ================= | ||
trans system | ||
===== ================= | ||
0 :math:`a x = b` | ||
1 :math:`a^T x = b` | ||
2 :math:`a^H x = b` | ||
===== ================= | ||
|
||
Default: ``0``. | ||
overwrite_b : {None, bool}, optional | ||
Whether to overwrite data in `b` (may increase performance). | ||
|
||
Default: ``False``. | ||
check_finite : {None, bool}, optional | ||
Whether to check that the input matrix contains only finite numbers. | ||
Disabling may give a performance gain, but may result in problems | ||
(crashes, non-termination) if the inputs do contain infinities or NaNs. | ||
|
||
Default: ``True``. | ||
|
||
Returns | ||
------- | ||
x : {(M,), (M, K)} dpnp.ndarray | ||
Solution to the system | ||
|
||
Warning | ||
------- | ||
This function synchronizes in order to validate array elements | ||
when ``check_finite=True``. | ||
|
||
vlad-perevezentsev marked this conversation as resolved.
Show resolved
Hide resolved
|
||
See Also | ||
-------- | ||
:obj:`dpnp.linalg.lu_factor` : LU factorize a matrix. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add |
||
|
||
Examples | ||
-------- | ||
>>> import dpnp as np | ||
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]]) | ||
>>> b = np.array([1, 1, 1, 1]) | ||
>>> lu, piv = np.linalg.lu_factor(A) | ||
>>> x = np.linalg.lu_solve((lu, piv), b) | ||
>>> np.allclose(A @ x - b, np.zeros((4,))) | ||
array(True) | ||
|
||
""" | ||
|
||
(lu, piv) = lu_and_piv | ||
dpnp.check_supported_arrays_type(lu, piv, b) | ||
assert_stacked_2d(lu) | ||
|
||
return dpnp_lu_solve( | ||
lu, | ||
piv, | ||
b, | ||
trans=trans, | ||
overwrite_b=overwrite_b, | ||
check_finite=check_finite, | ||
) | ||
|
||
|
||
def matmul(x1, x2, /): | ||
""" | ||
Computes the matrix product. | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2477,6 +2477,134 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): | |
return (a_h, ipiv_h) | ||
|
||
|
||
def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): | ||
""" | ||
dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True) | ||
|
||
Solve an equation system (SciPy-compatible behavior). | ||
|
||
This function mimics the behavior of `scipy.linalg.lu_solve` including | ||
support for `trans`, `overwrite_b`, `check_finite`, | ||
and 0-based pivot indexing. | ||
|
||
""" | ||
|
||
res_usm_type, exec_q = get_usm_allocations([lu, piv, b]) | ||
|
||
res_type = _common_type(lu, b) | ||
|
||
# TODO: add broadcasting | ||
if lu.shape[0] != b.shape[0]: | ||
raise ValueError( | ||
f"Shapes of lu {lu.shape} and b {b.shape} are incompatible" | ||
) | ||
|
||
if b.size == 0: | ||
return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type) | ||
|
||
if lu.ndim > 2: | ||
raise NotImplementedError("Batched matrices are not supported") | ||
|
||
if check_finite: | ||
if not dpnp.isfinite(lu).all(): | ||
raise ValueError( | ||
"LU factorization array must not contain infs or NaNs.\n" | ||
"Note that when a singular matrix is given, unlike " | ||
"dpnp.linalg.lu_factor returns an array containing NaN." | ||
) | ||
if not dpnp.isfinite(b).all(): | ||
raise ValueError( | ||
"Right-hand side array must not contain infs or NaNs" | ||
) | ||
|
||
lu_usm_arr = dpnp.get_usm_ndarray(lu) | ||
piv_usm_arr = dpnp.get_usm_ndarray(piv) | ||
b_usm_arr = dpnp.get_usm_ndarray(b) | ||
|
||
_manager = dpu.SequentialOrderManager[exec_q] | ||
dep_evs = _manager.submitted_events | ||
|
||
# oneMKL LAPACK getrf overwrites `lu`. | ||
lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type) | ||
|
||
# use DPCTL tensor function to fill the сopy of the input array | ||
# from the input array | ||
ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( | ||
src=lu_usm_arr, | ||
dst=lu_h.get_array(), | ||
sycl_queue=lu.sycl_queue, | ||
depends=dep_evs, | ||
) | ||
_manager.add_event_pair(ht_ev, lu_copy_ev) | ||
|
||
# oneMKL LAPACK getrf overwrites `piv`. | ||
piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type) | ||
|
||
# use DPCTL tensor function to fill the сopy of the pivot array | ||
# from the pivot array | ||
ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( | ||
src=piv_usm_arr, | ||
dst=piv_h.get_array(), | ||
sycl_queue=piv.sycl_queue, | ||
depends=dep_evs, | ||
) | ||
_manager.add_event_pair(ht_ev, piv_copy_ev) | ||
|
||
# SciPy-compatible behavior | ||
# Copy is required if: | ||
# - overwrite_b is False (always copy), | ||
# - dtype mismatch, | ||
# - not F-contiguous, | ||
# - not writeable | ||
if not overwrite_b or _is_copy_required(b, res_type): | ||
b_h = dpnp.empty_like( | ||
b, order="F", dtype=res_type, usm_type=res_usm_type | ||
) | ||
ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( | ||
src=b_usm_arr, | ||
dst=b_h.get_array(), | ||
sycl_queue=b.sycl_queue, | ||
depends=dep_evs, | ||
) | ||
_manager.add_event_pair(ht_ev, b_copy_ev) | ||
dep_evs = [lu_copy_ev, piv_copy_ev, b_copy_ev] | ||
else: | ||
# input is suitable for in-place modification | ||
b_h = b | ||
dep_evs = [lu_copy_ev, piv_copy_ev] | ||
|
||
# MKL lapack uses 1-origin while SciPy uses 0-origin | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess SciPy also uses MKL to call |
||
piv_h += 1 | ||
|
||
if not isinstance(trans, int): | ||
raise TypeError("`trans` must be an integer") | ||
|
||
# Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums | ||
if trans == 0: | ||
trans_mkl = li.Transpose.N | ||
elif trans == 1: | ||
trans_mkl = li.Transpose.T | ||
elif trans == 2: | ||
trans_mkl = li.Transpose.C | ||
else: | ||
raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)") | ||
|
||
# Call the LAPACK extension function _getrs | ||
# to solve the system of linear equations with an LU-factored | ||
# coefficient square matrix, with multiple right-hand sides. | ||
ht_ev, getrs_ev = li._getrs( | ||
exec_q, | ||
lu_h.get_array(), | ||
piv_h.get_array(), | ||
b_h.get_array(), | ||
trans_mkl, | ||
depends=dep_evs, | ||
) | ||
_manager.add_event_pair(ht_ev, getrs_ev) | ||
|
||
return b_h | ||
|
||
|
||
def dpnp_matrix_power(a, n): | ||
""" | ||
dpnp_matrix_power(a, n) | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.