@@ -44,6 +44,7 @@ namespace py = pybind11;
4444namespace type_utils = dpctl::tensor::type_utils;
4545
4646typedef sycl::event (*getrf_impl_fn_ptr_t )(sycl::queue &,
47+ const std::int64_t ,
4748 const std::int64_t ,
4849 char *,
4950 std::int64_t ,
@@ -56,6 +57,7 @@ static getrf_impl_fn_ptr_t getrf_dispatch_vector[dpctl_td_ns::num_types];
5657
5758template <typename T>
5859static sycl::event getrf_impl (sycl::queue &exec_q,
60+ const std::int64_t m,
5961 const std::int64_t n,
6062 char *in_a,
6163 std::int64_t lda,
@@ -82,11 +84,11 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
8284
8385 getrf_event = mkl_lapack::getrf (
8486 exec_q,
85- n , // The order of the square matrix A (0 ≤ n ).
87+ m , // The number of rows in the input matrix A (0 ≤ m ).
8688 // It must be a non-negative integer.
87- n, // The number of columns in the square matrix A (0 ≤ n).
89+ n, // The number of columns in the input matrix A (0 ≤ n).
8890 // It must be a non-negative integer.
89- a, // Pointer to the square matrix A (n x n).
91+ a, // Pointer to the input matrix A (n x n).
9092 lda, // The leading dimension of matrix A.
9193 // It must be at least max(1, n).
9294 ipiv, // Pointer to the output array of pivot indices.
@@ -99,7 +101,7 @@ static sycl::event getrf_impl(sycl::queue &exec_q,
99101
100102 if (info < 0 ) {
101103 error_msg << " Parameter number " << -info
102- << " had an illegal value. " ;
104+ << " had an illegal value" ;
103105 }
104106 else if (info == scratchpad_size && e.detail () != 0 ) {
105107 error_msg
@@ -168,13 +170,13 @@ std::pair<sycl::event, sycl::event>
168170 if (a_array_nd != 2 ) {
169171 throw py::value_error (
170172 " The input array has ndim=" + std::to_string (a_array_nd) +
171- " , but a 2-dimensional array is expected. " );
173+ " , but a 2-dimensional array is expected" );
172174 }
173175
174176 if (ipiv_array_nd != 1 ) {
175177 throw py::value_error (" The array of pivot indices has ndim=" +
176178 std::to_string (ipiv_array_nd) +
177- " , but a 1-dimensional array is expected. " );
179+ " , but a 1-dimensional array is expected" );
178180 }
179181
180182 // check compatibility of execution queue and allocation queue
@@ -190,10 +192,12 @@ std::pair<sycl::event, sycl::event>
190192 }
191193
192194 bool is_a_array_c_contig = a_array.is_c_contiguous ();
195+ bool is_a_array_f_contig = a_array.is_f_contiguous ();
193196 bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous ();
194- if (!is_a_array_c_contig) {
197+ if (!is_a_array_c_contig && !is_a_array_f_contig ) {
195198 throw py::value_error (" The input array "
196- " must be C-contiguous" );
199+ " must be must be either C-contiguous "
200+ " or F-contiguous" );
197201 }
198202 if (!is_ipiv_array_c_contig) {
199203 throw py::value_error (" The array of pivot indices "
@@ -208,27 +212,33 @@ std::pair<sycl::event, sycl::event>
208212 if (getrf_fn == nullptr ) {
209213 throw py::value_error (
210214 " No getrf implementation defined for the provided type "
211- " of the input matrix. " );
215+ " of the input matrix" );
212216 }
213217
214218 auto ipiv_types = dpctl_td_ns::usm_ndarray_types ();
215219 int ipiv_array_type_id =
216220 ipiv_types.typenum_to_lookup_id (ipiv_array.get_typenum ());
217221
218222 if (ipiv_array_type_id != static_cast <int >(dpctl_td_ns::typenum_t ::INT64)) {
219- throw py::value_error (" The type of 'ipiv_array' must be int64. " );
223+ throw py::value_error (" The type of 'ipiv_array' must be int64" );
220224 }
221225
222- const std::int64_t n = a_array.get_shape_raw ()[0 ];
226+ const py::ssize_t *a_array_shape = a_array.get_shape_raw ();
227+ const std::int64_t m = a_array_shape[0 ];
228+ const std::int64_t n = a_array_shape[1 ];
229+ const std::int64_t lda = std::max<size_t >(1UL , m);
230+
231+ if (ipiv_array.get_size () != std::min (m, n)) {
232+ throw py::value_error (" The size of 'ipiv_array' must be min(m, n)" );
233+ }
223234
224235 char *a_array_data = a_array.get_data ();
225- const std::int64_t lda = std::max<size_t >(1UL , n);
226236
227237 char *ipiv_array_data = ipiv_array.get_data ();
228238 std::int64_t *d_ipiv = reinterpret_cast <std::int64_t *>(ipiv_array_data);
229239
230240 std::vector<sycl::event> host_task_events;
231- sycl::event getrf_ev = getrf_fn (exec_q, n, a_array_data, lda, d_ipiv,
241+ sycl::event getrf_ev = getrf_fn (exec_q, m, n, a_array_data, lda, d_ipiv,
232242 dev_info, host_task_events, depends);
233243
234244 sycl::event args_ev = dpctl::utils::keep_args_alive (
0 commit comments