2626
2727#include < pybind11/pybind11.h>
2828
29+ // dpctl tensor headers
30+ #include " utils/memory_overlap.hpp"
31+ #include " utils/type_utils.hpp"
32+
2933#include " heevd.hpp"
34+ #include " types_matrix.hpp"
3035
3136#include " dpnp_utils.hpp"
3237
@@ -42,19 +47,34 @@ namespace lapack
4247
4348namespace mkl_lapack = oneapi::mkl::lapack;
4449namespace py = pybind11;
50+ namespace type_utils = dpctl::tensor::type_utils;
51+
52+ typedef sycl::event (*heevd_impl_fn_ptr_t )(sycl::queue,
53+ const oneapi::mkl::job,
54+ const oneapi::mkl::uplo,
55+ const std::int64_t ,
56+ char *,
57+ char *,
58+ std::vector<sycl::event>&,
59+ const std::vector<sycl::event>&);
60+
61+ static heevd_impl_fn_ptr_t heevd_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];
4562
4663template <typename T, typename RealT>
47- static sycl::event call_heevd (sycl::queue exec_q,
64+ static sycl::event heevd_impl (sycl::queue exec_q,
4865 const oneapi::mkl::job jobz,
4966 const oneapi::mkl::uplo upper_lower,
5067 const std::int64_t n,
51- T* a ,
52- RealT* w ,
68+ char * in_a ,
69+ char * out_w ,
5370 std::vector<sycl::event>& host_task_events,
5471 const std::vector<sycl::event>& depends)
5572{
56- validate_type_for_device<T>(exec_q);
57- validate_type_for_device<RealT>(exec_q);
73+ type_utils::validate_type_for_device<T>(exec_q);
74+ type_utils::validate_type_for_device<RealT>(exec_q);
75+
76+ T* a = reinterpret_cast <T*>(in_a);
77+ RealT* w = reinterpret_cast <RealT*>(out_w);
5878
5979 const std::int64_t lda = std::max<size_t >(1UL , n);
6080 const std::int64_t scratchpad_size = mkl_lapack::heevd_scratchpad_size<T>(exec_q, jobz, upper_lower, n, lda);
@@ -163,13 +183,11 @@ std::pair<sycl::event, sycl::event> heevd(sycl::queue exec_q,
163183 throw py::value_error (" Execution queue is not compatible with allocation queues" );
164184 }
165185
166- // check that arrays do not overlap, and concurrent access is safe.
167- // TODO: need to be exposed by DPCTL headers
168- // auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
169- // if (overlap(eig_vecs, eig_vals))
170- // {
171- // throw py::value_error("Arrays index overlapping segments of memory");
172- // }
186+ auto const & overlap = dpctl::tensor::overlap::MemoryOverlap ();
187+ if (overlap (eig_vecs, eig_vals))
188+ {
189+ throw py::value_error (" Arrays with eigenvectors and eigenvalues are overlapping segments of memory" );
190+ }
173191
174192 bool is_eig_vecs_f_contig = eig_vecs.is_f_contiguous ();
175193 bool is_eig_vals_c_contig = eig_vals.is_c_contiguous ();
@@ -182,38 +200,51 @@ std::pair<sycl::event, sycl::event> heevd(sycl::queue exec_q,
182200 throw py::value_error (" An array with output eigenvalues must be C-contiguous" );
183201 }
184202
185- int eig_vecs_typenum = eig_vecs. get_typenum ();
186- int eig_vals_typenum = eig_vals. get_typenum ();
187- auto const & dpctl_capi = dpctl::detail::dpctl_capi::get ( );
203+ auto array_types = dpctl_td_ns::usm_ndarray_types ();
204+ int eig_vecs_type_id = array_types. typenum_to_lookup_id (eig_vecs. get_typenum () );
205+ int eig_vals_type_id = array_types. typenum_to_lookup_id (eig_vals. get_typenum () );
188206
189- sycl::event heevd_ev;
190- std::vector<sycl::event> host_task_events;
207+ heevd_impl_fn_ptr_t heevd_fn = heevd_dispatch_table[eig_vecs_type_id][eig_vals_type_id];
208+ if (heevd_fn == nullptr )
209+ {
210+ throw py::value_error (" No heevd implementation defined for a pair of type for eigenvectors and eigenvalues" );
211+ }
212+
213+ char * eig_vecs_data = eig_vecs.get_data ();
214+ char * eig_vals_data = eig_vals.get_data ();
191215
192216 const std::int64_t n = eig_vecs_shape[0 ];
193217 const oneapi::mkl::job jobz_val = static_cast <oneapi::mkl::job>(jobz);
194218 const oneapi::mkl::uplo uplo_val = static_cast <oneapi::mkl::uplo>(upper_lower);
195219
196- if ((eig_vecs_typenum == dpctl_capi.UAR_CDOUBLE_ ) && (eig_vals_typenum == dpctl_capi.UAR_DOUBLE_ ))
197- {
198- std::complex <double >* a = reinterpret_cast <std::complex <double >*>(eig_vecs.get_data ());
199- double * w = reinterpret_cast <double *>(eig_vals.get_data ());
220+ std::vector<sycl::event> host_task_events;
221+ sycl::event heevd_ev =
222+ heevd_fn (exec_q, jobz_val, uplo_val, n, eig_vecs_data, eig_vals_data, host_task_events, depends);
200223
201- heevd_ev = call_heevd (exec_q, jobz_val, uplo_val, n, a, w, host_task_events, depends);
202- }
203- else if ((eig_vecs_typenum == dpctl_capi.UAR_CFLOAT_ ) && (eig_vals_typenum == dpctl_capi.UAR_FLOAT_ ))
204- {
205- std::complex <float >* a = reinterpret_cast <std::complex <float >*>(eig_vecs.get_data ());
206- float * w = reinterpret_cast <float *>(eig_vals.get_data ());
224+ sycl::event args_ev = dpctl::utils::keep_args_alive (exec_q, {eig_vecs, eig_vals}, host_task_events);
225+ return std::make_pair (args_ev, heevd_ev);
226+ }
207227
208- heevd_ev = call_heevd (exec_q, jobz_val, uplo_val, n, a, w, host_task_events, depends);
209- }
210- else
228+ template <typename fnT, typename T, typename RealT>
229+ struct HeevdContigFactory
230+ {
231+ fnT get ()
211232 {
212- throw py::value_error (" Unexpected types of either eigenvectors or eigenvalues" );
233+ if constexpr (types::HeevdTypePairSupportFactory<T, RealT>::is_defined)
234+ {
235+ return heevd_impl<T, RealT>;
236+ }
237+ else
238+ {
239+ return nullptr ;
240+ }
213241 }
242+ };
214243
215- sycl::event args_ev = dpctl::utils::keep_args_alive (exec_q, {eig_vecs, eig_vals}, host_task_events);
216- return std::make_pair (args_ev, heevd_ev);
244+ void init_heevd_dispatch_table (void )
245+ {
246+ dpctl_td_ns::DispatchTableBuilder<heevd_impl_fn_ptr_t , HeevdContigFactory, dpctl_td_ns::num_types> contig;
247+ contig.populate_dispatch_table (heevd_dispatch_table);
217248}
218249}
219250}
0 commit comments