Skip to content

Commit 28e86bb

Browse files
Merge a3d31c8 into 230ba6a
2 parents 230ba6a + a3d31c8 commit 28e86bb

File tree

7 files changed

+480
-128
lines changed

7 files changed

+480
-128
lines changed

dpnp/backend/extensions/lapack/getrf.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ namespace py = pybind11;
4444
namespace type_utils = dpctl::tensor::type_utils;
4545

4646
typedef sycl::event (*getrf_impl_fn_ptr_t)(sycl::queue &,
47+
const std::int64_t,
4748
const std::int64_t,
4849
char *,
4950
std::int64_t,
@@ -56,6 +57,7 @@ static getrf_impl_fn_ptr_t getrf_dispatch_vector[dpctl_td_ns::num_types];
5657

5758
template <typename T>
5859
static sycl::event getrf_impl(sycl::queue &exec_q,
60+
const std::int64_t m,
5961
const std::int64_t n,
6062
char *in_a,
6163
std::int64_t lda,
@@ -82,11 +84,11 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
8284

8385
getrf_event = mkl_lapack::getrf(
8486
exec_q,
85-
n, // The order of the square matrix A (0 ≤ n).
87+
m, // The number of rows in the input matrix A (0 ≤ m).
8688
// It must be a non-negative integer.
87-
n, // The number of columns in the square matrix A (0 ≤ n).
89+
n, // The number of columns in the input matrix A (0 ≤ n).
8890
// It must be a non-negative integer.
89-
a, // Pointer to the square matrix A (n x n).
91+
a, // Pointer to the input matrix A (n x n).
9092
lda, // The leading dimension of matrix A.
9193
// It must be at least max(1, n).
9294
ipiv, // Pointer to the output array of pivot indices.
@@ -99,7 +101,7 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
99101

100102
if (info < 0) {
101103
error_msg << "Parameter number " << -info
102-
<< " had an illegal value.";
104+
<< " had an illegal value";
103105
}
104106
else if (info == scratchpad_size && e.detail() != 0) {
105107
error_msg
@@ -168,13 +170,13 @@ std::pair<sycl::event, sycl::event>
168170
if (a_array_nd != 2) {
169171
throw py::value_error(
170172
"The input array has ndim=" + std::to_string(a_array_nd) +
171-
", but a 2-dimensional array is expected.");
173+
", but a 2-dimensional array is expected");
172174
}
173175

174176
if (ipiv_array_nd != 1) {
175177
throw py::value_error("The array of pivot indices has ndim=" +
176178
std::to_string(ipiv_array_nd) +
177-
", but a 1-dimensional array is expected.");
179+
", but a 1-dimensional array is expected");
178180
}
179181

180182
// check compatibility of execution queue and allocation queue
@@ -190,10 +192,12 @@ std::pair<sycl::event, sycl::event>
190192
}
191193

192194
bool is_a_array_c_contig = a_array.is_c_contiguous();
195+
bool is_a_array_f_contig = a_array.is_f_contiguous();
193196
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous();
194-
if (!is_a_array_c_contig) {
197+
if (!is_a_array_c_contig && !is_a_array_f_contig) {
195198
throw py::value_error("The input array "
196-
"must be C-contiguous");
199+
"must be must be either C-contiguous "
200+
"or F-contiguous");
197201
}
198202
if (!is_ipiv_array_c_contig) {
199203
throw py::value_error("The array of pivot indices "
@@ -208,27 +212,33 @@ std::pair<sycl::event, sycl::event>
208212
if (getrf_fn == nullptr) {
209213
throw py::value_error(
210214
"No getrf implementation defined for the provided type "
211-
"of the input matrix.");
215+
"of the input matrix");
212216
}
213217

214218
auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
215219
int ipiv_array_type_id =
216220
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());
217221

218222
if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
219-
throw py::value_error("The type of 'ipiv_array' must be int64.");
223+
throw py::value_error("The type of 'ipiv_array' must be int64");
220224
}
221225

222-
const std::int64_t n = a_array.get_shape_raw()[0];
226+
const py::ssize_t *a_array_shape = a_array.get_shape_raw();
227+
const std::int64_t m = a_array_shape[0];
228+
const std::int64_t n = a_array_shape[1];
229+
const std::int64_t lda = std::max<size_t>(1UL, m);
230+
231+
if (ipiv_array.get_size() != std::min(m, n)) {
232+
throw py::value_error("The size of 'ipiv_array' must be min(m, n)");
233+
}
223234

224235
char *a_array_data = a_array.get_data();
225-
const std::int64_t lda = std::max<size_t>(1UL, n);
226236

227237
char *ipiv_array_data = ipiv_array.get_data();
228238
std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);
229239

230240
std::vector<sycl::event> host_task_events;
231-
sycl::event getrf_ev = getrf_fn(exec_q, n, a_array_data, lda, d_ipiv,
241+
sycl::event getrf_ev = getrf_fn(exec_q, m, n, a_array_data, lda, d_ipiv,
232242
dev_info, host_task_events, depends);
233243

234244
sycl::event args_ev = dpctl::utils::keep_args_alive(

dpnp/backend/extensions/lapack/getrf_batch.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,12 @@ std::pair<sycl::event, sycl::event>
221221
}
222222

223223
bool is_a_array_c_contig = a_array.is_c_contiguous();
224+
bool is_a_array_f_contig = a_array.is_f_contiguous();
224225
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous();
225-
if (!is_a_array_c_contig) {
226+
if (!is_a_array_c_contig && !is_a_array_f_contig) {
226227
throw py::value_error("The input array "
227-
"must be C-contiguous");
228+
"must be must be either C-contiguous "
229+
"or F-contiguous");
228230
}
229231
if (!is_ipiv_array_c_contig) {
230232
throw py::value_error("The array of pivot indices "

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
dpnp_eigh,
5757
dpnp_inv,
5858
dpnp_lstsq,
59+
dpnp_lu_factor,
5960
dpnp_matrix_power,
6061
dpnp_matrix_rank,
6162
dpnp_multi_dot,
@@ -79,6 +80,7 @@
7980
"eigvalsh",
8081
"inv",
8182
"lstsq",
83+
"lu_factor",
8284
"matmul",
8385
"matrix_norm",
8486
"matrix_power",
@@ -901,6 +903,68 @@ def lstsq(a, b, rcond=None):
901903
return dpnp_lstsq(a, b, rcond=rcond)
902904

903905

906+
def lu_factor(a, overwrite_a=False, check_finite=True):
907+
"""
908+
Compute the pivoted LU decomposition of a matrix.
909+
910+
The decomposition is::
911+
912+
A = P @ L @ U
913+
914+
where `P` is a permutation matrix, `L` is lower triangular with unit
915+
diagonal elements, and `U` is upper triangular.
916+
917+
Parameters
918+
----------
919+
a : (M, N) {dpnp.ndarray, usm_ndarray}
920+
Input array to decompose.
921+
overwrite_a : {None, bool}, optional
922+
Whether to overwrite data in `a` (may increase performance)
923+
Default: ``False``.
924+
check_finite : {None, bool}, optional
925+
Whether to check that the input matrix contains only finite numbers.
926+
Disabling may give a performance gain, but may result in problems
927+
(crashes, non-termination) if the inputs do contain infinities or NaNs.
928+
929+
Returns
930+
-------
931+
lu :(M, N) dpnp.ndarray
932+
Matrix containing U in its upper triangle, and L in its lower triangle.
933+
The unit diagonal elements of L are not stored.
934+
piv (K, ): dpnp.ndarray
935+
Pivot indices representing the permutation matrix P:
936+
row i of matrix was interchanged with row piv[i].
937+
``K = min(M, N)``.
938+
939+
Warning
940+
-------
941+
This function synchronizes in order to validate array elements
942+
when ``check_finite=True``.
943+
944+
Limitations
945+
-----------
946+
Only two-dimensional input matrices are supported.
947+
Otherwise, the function raises ``NotImplementedError`` exception.
948+
949+
Examples
950+
--------
951+
>>> import dpnp as np
952+
>>> a = np.array([[4., 3.], [6., 3.]])
953+
>>> lu, piv = np.linalg.lu_factor(a)
954+
>>> lu
955+
array([[6. , 3. ],
956+
[0.66666667, 1. ]])
957+
>>> piv
958+
array([1, 1])
959+
960+
"""
961+
962+
dpnp.check_supported_arrays_type(a)
963+
assert_stacked_2d(a)
964+
965+
return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite)
966+
967+
904968
def matmul(x1, x2, /):
905969
"""
906970
Computes the matrix product.

0 commit comments

Comments
 (0)