Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Added `dpnp.ndarray.__contains__` method [#2534](https://github.com/IntelPython/dpnp/pull/2534)
* Added implementation of `dpnp.linalg.lu_factor` (SciPy-compatible) [#2557](https://github.com/IntelPython/dpnp/pull/2557), [#2565](https://github.com/IntelPython/dpnp/pull/2565)
* Added implementation of `dpnp.piecewise` [#2550](https://github.com/IntelPython/dpnp/pull/2550)
* Added implementation of `dpnp.linalg.lu_solve` for 2D inputs (SciPy-compatible) [#2575](https://github.com/IntelPython/dpnp/pull/2575)

### Changed

Expand Down
1 change: 1 addition & 0 deletions doc/reference/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ Solving linear equations
dpnp.linalg.solve
dpnp.linalg.tensorsolve
dpnp.linalg.lstsq
dpnp.linalg.lu_solve
dpnp.linalg.inv
dpnp.linalg.pinv
dpnp.linalg.tensorinv
Expand Down
77 changes: 77 additions & 0 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
dpnp_inv,
dpnp_lstsq,
dpnp_lu_factor,
dpnp_lu_solve,
dpnp_matrix_power,
dpnp_matrix_rank,
dpnp_multi_dot,
Expand All @@ -81,6 +82,7 @@
"inv",
"lstsq",
"lu_factor",
"lu_solve",
"matmul",
"matrix_norm",
"matrix_power",
Expand Down Expand Up @@ -966,6 +968,81 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite)


def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """
    Solve a linear system, :math:`a x = b`, given the LU factorization of `a`.

    For full documentation refer to :obj:`scipy.linalg.lu_solve`.

    Parameters
    ----------
    lu_and_piv : {tuple of dpnp.ndarrays or usm_ndarrays}
        Pair ``(lu, piv)``: the LU factorization of matrix `a` (M, M)
        together with the pivot indices, e.g. the result of
        :obj:`dpnp.linalg.lu_factor`.
    b : {(M,), (M, K)} {dpnp.ndarray, usm_ndarray}
        Right-hand side.
    trans : {0, 1, 2}, optional
        Type of system to solve:

        ===== =================
        trans system
        ===== =================
        0     :math:`a x = b`
        1     :math:`a^T x = b`
        2     :math:`a^H x = b`
        ===== =================

        Default: ``0``.
    overwrite_b : {None, bool}, optional
        Whether to overwrite data in `b` (may increase performance).

        Default: ``False``.
    check_finite : {None, bool}, optional
        Whether to check that the input arrays contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

        Default: ``True``.

    Returns
    -------
    x : {(M,), (M, K)} dpnp.ndarray
        Solution to the system.

    Warning
    -------
    This function synchronizes in order to validate array elements
    when ``check_finite=True``.

    See Also
    --------
    :obj:`dpnp.linalg.lu_factor` : LU factorize a matrix.

    Examples
    --------
    >>> import dpnp as np
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = np.linalg.lu_factor(A)
    >>> x = np.linalg.lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    array(True)

    """

    # TODO(review): a reviewer asked for a matching "See Also" entry
    # (pointing at lu_solve) in the lu_factor docstring as well.

    # Split the factorization pair and validate the inputs before
    # delegating to the backend implementation.
    lu, piv = lu_and_piv
    dpnp.check_supported_arrays_type(lu, piv, b)
    assert_stacked_2d(lu)

    solver_options = {
        "trans": trans,
        "overwrite_b": overwrite_b,
        "check_finite": check_finite,
    }
    return dpnp_lu_solve(lu, piv, b, **solver_options)


def matmul(x1, x2, /):
"""
Computes the matrix product.
Expand Down
128 changes: 128 additions & 0 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2477,6 +2477,134 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
return (a_h, ipiv_h)


def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
    """
    dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True)

    Solve an equation system (SciPy-compatible behavior).

    This function mimics the behavior of `scipy.linalg.lu_solve` including
    support for `trans`, `overwrite_b`, `check_finite`,
    and 0-based pivot indexing.

    The returned array is `b` itself only when `overwrite_b` is true and
    ``_is_copy_required(b, res_type)`` is false; otherwise the solution is
    written into a newly allocated copy of `b`.

    Raises
    ------
    ValueError
        If the leading dimensions of `lu` and `b` differ, if `check_finite`
        is true and `lu` or `b` contains infs or NaNs, or if `trans` is an
        integer other than 0, 1 or 2.
    NotImplementedError
        If `lu` has more than two dimensions.
    TypeError
        If `trans` is not an integer.

    """

    res_usm_type, exec_q = get_usm_allocations([lu, piv, b])

    res_type = _common_type(lu, b)

    # TODO: add broadcasting
    if lu.shape[0] != b.shape[0]:
        raise ValueError(
            f"Shapes of lu {lu.shape} and b {b.shape} are incompatible"
        )

    if b.size == 0:
        return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)

    if lu.ndim > 2:
        raise NotImplementedError("Batched matrices are not supported")

    if check_finite:
        if not dpnp.isfinite(lu).all():
            raise ValueError(
                "LU factorization array must not contain infs or NaNs.\n"
                "Note that when a singular matrix is given, "
                "dpnp.linalg.lu_factor returns an array containing NaN."
            )
        if not dpnp.isfinite(b).all():
            raise ValueError(
                "Right-hand side array must not contain infs or NaNs"
            )

    # Validate `trans` before any device allocation or copy is issued, so
    # that an invalid value fails without enqueueing wasted work.
    if not isinstance(trans, int):
        raise TypeError("`trans` must be an integer")

    # Map SciPy-style trans codes (0, 1, 2) to MKL transpose enums
    if trans == 0:
        trans_mkl = li.Transpose.N
    elif trans == 1:
        trans_mkl = li.Transpose.T
    elif trans == 2:
        trans_mkl = li.Transpose.C
    else:
        raise ValueError("`trans` must be 0 (N), 1 (T), or 2 (C)")

    lu_usm_arr = dpnp.get_usm_ndarray(lu)
    piv_usm_arr = dpnp.get_usm_ndarray(piv)
    b_usm_arr = dpnp.get_usm_ndarray(b)

    _manager = dpu.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events

    # Pass an F-ordered copy of `lu` in the result dtype to the LAPACK call,
    # leaving the caller's array untouched.
    lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type)

    # use DPCTL tensor function to fill the copy of the input array
    # from the input array
    ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=lu_usm_arr,
        dst=lu_h.get_array(),
        sycl_queue=lu.sycl_queue,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, lu_copy_ev)

    # The pivots are shifted in place below, so operate on a private copy
    # rather than on the caller's `piv`.
    piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type)

    # use DPCTL tensor function to fill the copy of the pivot array
    # from the pivot array
    ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=piv_usm_arr,
        dst=piv_h.get_array(),
        sycl_queue=piv.sycl_queue,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, piv_copy_ev)

    # SciPy-compatible behavior
    # Copy is required if:
    # - overwrite_b is False (always copy),
    # - dtype mismatch,
    # - not F-contiguous,
    # - not writeable
    if not overwrite_b or _is_copy_required(b, res_type):
        b_h = dpnp.empty_like(
            b, order="F", dtype=res_type, usm_type=res_usm_type
        )
        ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=b_usm_arr,
            dst=b_h.get_array(),
            sycl_queue=b.sycl_queue,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_ev, b_copy_ev)
    else:
        # input is suitable for in-place modification
        b_h = b

    # `piv` uses the 0-based pivot convention of this SciPy-compatible API,
    # whereas the LAPACK getrs routine expects 1-based pivot indices, so the
    # private copy is shifted by one before the call.
    piv_h += 1

    # Depend on everything submitted so far rather than only on the copy
    # events, so that the solve is also ordered after the in-place pivot
    # shift above.
    # NOTE(review): assumes the in-place add registers its event with the
    # SequentialOrderManager of `exec_q` - confirm.
    dep_evs = _manager.submitted_events

    # Call the LAPACK extension function _getrs
    # to solve the system of linear equations with an LU-factored
    # coefficient square matrix, with multiple right-hand sides.
    ht_ev, getrs_ev = li._getrs(
        exec_q,
        lu_h.get_array(),
        piv_h.get_array(),
        b_h.get_array(),
        trans_mkl,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, getrs_ev)

    return b_h


def dpnp_matrix_power(a, n):
"""
dpnp_matrix_power(a, n)
Expand Down
Loading
Loading