Skip to content

Commit 958537a

Browse files
Merge 6e7b9f7 into d975818
2 parents d975818 + 6e7b9f7 commit 958537a

File tree

4 files changed

+19
-15
lines changed

4 files changed

+19
-15
lines changed

doc/reference/linalg.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Decompositions
4343
dpnp.linalg.cholesky
4444
dpnp.linalg.outer
4545
dpnp.linalg.qr
46+
dpnp.linalg.lu_factor
4647
dpnp.linalg.svd
4748
dpnp.linalg.svdvals
4849

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -919,19 +919,22 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
919919
a : (M, N) {dpnp.ndarray, usm_ndarray}
920920
Input array to decompose.
921921
overwrite_a : {None, bool}, optional
922-
Whether to overwrite data in `a` (may increase performance)
922+
Whether to overwrite data in `a` (may increase performance).
923+
923924
Default: ``False``.
924925
check_finite : {None, bool}, optional
925926
Whether to check that the input matrix contains only finite numbers.
926927
Disabling may give a performance gain, but may result in problems
927928
(crashes, non-termination) if the inputs do contain infinities or NaNs.
928929
930+
Default: ``True``.
931+
929932
Returns
930933
-------
931-
lu :(M, N) dpnp.ndarray
934+
lu : (M, N) dpnp.ndarray
932935
Matrix containing U in its upper triangle, and L in its lower triangle.
933936
The unit diagonal elements of L are not stored.
934-
piv (K, ): dpnp.ndarray
937+
piv : (K, ) dpnp.ndarray
935938
Pivot indices representing the permutation matrix P:
936939
row i of matrix was interchanged with row piv[i].
937940
``K = min(M, N)``.

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
481481
if any(dev_info_h):
482482
diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
483483
warn(
484-
f"Diagonal number {diag_nums} are exactly zero. "
484+
f"Diagonal numbers {diag_nums} are exactly zero. "
485485
"Singular matrix.",
486486
RuntimeWarning,
487487
stacklevel=2,
@@ -2493,14 +2493,14 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
24932493
a_h.get_array(),
24942494
ipiv_h.get_array(),
24952495
dev_info_h,
2496-
depends=[copy_ev] if copy_ev is not None else [],
2496+
depends=[copy_ev] if copy_ev is not None else _manager.submitted_events,
24972497
)
24982498
_manager.add_event_pair(ht_ev, getrf_ev)
24992499

25002500
if any(dev_info_h):
25012501
diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
25022502
warn(
2503-
f"Diagonal number {diag_nums} are exactly zero. Singular matrix.",
2503+
f"Diagonal number {diag_nums} is exactly zero. Singular matrix.",
25042504
RuntimeWarning,
25052505
stacklevel=2,
25062506
)

dpnp/tests/test_linalg.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1911,7 +1911,7 @@ def test_lu_factor(self, shape, order, dtype):
19111911
A_cast = a_dp.astype(LU.dtype, copy=False)
19121912
PA = self._apply_pivots_rows(A_cast, piv)
19131913

1914-
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
1914+
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)
19151915

19161916
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
19171917
def test_overwrite_inplace(self, dtype):
@@ -1928,7 +1928,7 @@ def test_overwrite_inplace(self, dtype):
19281928
PA = self._apply_pivots_rows(a_dp_orig, piv)
19291929
LU = L @ U
19301930

1931-
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
1931+
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)
19321932

19331933
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
19341934
def test_overwrite_copy(self, dtype):
@@ -1945,7 +1945,7 @@ def test_overwrite_copy(self, dtype):
19451945
PA = self._apply_pivots_rows(a_dp_orig, piv)
19461946
LU = L @ U
19471947

1948-
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
1948+
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)
19491949

19501950
def test_overwrite_copy_special(self):
19511951
# F-contig but dtype != res_type
@@ -1972,7 +1972,7 @@ def test_overwrite_copy_special(self):
19721972
a_orig.astype(L.dtype, copy=False), piv
19731973
)
19741974
LU = L @ U
1975-
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
1975+
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)
19761976

19771977
@pytest.mark.parametrize("shape", [(0, 0), (0, 2), (2, 0)])
19781978
def test_empty_inputs(self, shape):
@@ -2003,7 +2003,7 @@ def test_strided(self, sl):
20032003
PA = self._apply_pivots_rows(a_dp, piv)
20042004
LU = L @ U
20052005

2006-
assert_allclose(LU, PA, rtol=1e-6, atol=1e-6)
2006+
assert dpnp.allclose(LU, PA, rtol=1e-6, atol=1e-6)
20072007

20082008
def test_singular_matrix(self):
20092009
a_dp = dpnp.array([[1.0, 2.0], [2.0, 4.0]])
@@ -2070,7 +2070,7 @@ def test_lu_factor_batched(self, shape, order, dtype):
20702070
L, U = self._split_lu(lu_3d[i], m, n)
20712071
A_cast = a_3d[i].astype(L.dtype, copy=False)
20722072
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
2073-
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
2073+
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
20742074

20752075
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
20762076
@pytest.mark.parametrize("order", ["C", "F"])
@@ -2082,7 +2082,7 @@ def test_overwrite(self, dtype, order):
20822082
)
20832083

20842084
assert lu is not a_dp
2085-
assert_allclose(a_dp, a_dp_orig)
2085+
assert dpnp.allclose(a_dp, a_dp_orig)
20862086

20872087
m = n = 2
20882088
lu_3d = lu.reshape((-1, m, n))
@@ -2092,7 +2092,7 @@ def test_overwrite(self, dtype, order):
20922092
L, U = self._split_lu(lu_3d[i], m, n)
20932093
A_cast = a_3d[i].astype(L.dtype, copy=False)
20942094
PA = self._apply_pivots_rows(A_cast, piv_2d[i])
2095-
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
2095+
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
20962096

20972097
@pytest.mark.parametrize(
20982098
"shape", [(0, 2, 2), (2, 0, 2), (2, 2, 0), (0, 0, 0)]
@@ -2119,7 +2119,7 @@ def test_strided(self):
21192119
PA = self._apply_pivots_rows(
21202120
a_stride[i].astype(L.dtype, copy=False), piv[i]
21212121
)
2122-
assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
2122+
assert dpnp.allclose(L @ U, PA, rtol=1e-6, atol=1e-6)
21232123

21242124
def test_singular_matrix(self):
21252125
a = dpnp.zeros((3, 2, 2), dtype=dpnp.default_float_type())

0 commit comments

Comments
 (0)