Skip to content

Commit 476009e

Browse files
authored
Merge branch 'master' into strided_dot
2 parents 97817ed + 771653b commit 476009e

File tree

9 files changed

+211
-151
lines changed

9 files changed

+211
-151
lines changed

dpnp/backend/kernels/dpnp_krnl_common.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,9 +1147,21 @@ void func_map_init_linalg(func_map_t &fmap)
11471147
eft_DBL, (void *)dpnp_eig_default_c<double, double>};
11481148

11491149
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_INT][eft_INT] = {
1150-
eft_DBL, (void *)dpnp_eig_ext_c<int32_t, double>};
1150+
get_default_floating_type<>(),
1151+
(void *)dpnp_eig_ext_c<
1152+
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
1153+
get_default_floating_type<std::false_type>(),
1154+
(void *)dpnp_eig_ext_c<
1155+
int32_t, func_type_map_t::find_type<
1156+
get_default_floating_type<std::false_type>()>>};
11511157
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_LNG][eft_LNG] = {
1152-
eft_DBL, (void *)dpnp_eig_ext_c<int64_t, double>};
1158+
get_default_floating_type<>(),
1159+
(void *)dpnp_eig_ext_c<
1160+
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
1161+
get_default_floating_type<std::false_type>(),
1162+
(void *)dpnp_eig_ext_c<
1163+
int64_t, func_type_map_t::find_type<
1164+
get_default_floating_type<std::false_type>()>>};
11531165
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_FLT][eft_FLT] = {
11541166
eft_FLT, (void *)dpnp_eig_ext_c<float, float>};
11551167
fmap[DPNPFuncName::DPNP_FN_EIG_EXT][eft_DBL][eft_DBL] = {
@@ -1165,9 +1177,21 @@ void func_map_init_linalg(func_map_t &fmap)
11651177
eft_DBL, (void *)dpnp_eigvals_default_c<double, double>};
11661178

11671179
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_INT][eft_INT] = {
1168-
eft_DBL, (void *)dpnp_eigvals_ext_c<int32_t, double>};
1180+
get_default_floating_type<>(),
1181+
(void *)dpnp_eigvals_ext_c<
1182+
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
1183+
get_default_floating_type<std::false_type>(),
1184+
(void *)dpnp_eigvals_ext_c<
1185+
int32_t, func_type_map_t::find_type<
1186+
get_default_floating_type<std::false_type>()>>};
11691187
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_LNG][eft_LNG] = {
1170-
eft_DBL, (void *)dpnp_eigvals_ext_c<int64_t, double>};
1188+
get_default_floating_type<>(),
1189+
(void *)dpnp_eigvals_ext_c<
1190+
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
1191+
get_default_floating_type<std::false_type>(),
1192+
(void *)dpnp_eigvals_ext_c<
1193+
int64_t, func_type_map_t::find_type<
1194+
get_default_floating_type<std::false_type>()>>};
11711195
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_FLT][eft_FLT] = {
11721196
eft_FLT, (void *)dpnp_eigvals_ext_c<float, float>};
11731197
fmap[DPNPFuncName::DPNP_FN_EIGVALS_EXT][eft_DBL][eft_DBL] = {

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -874,16 +874,28 @@ void func_map_init_linalg_func(func_map_t &fmap)
874874
fmap[DPNPFuncName::DPNP_FN_INV][eft_LNG][eft_LNG] = {
875875
eft_DBL, (void *)dpnp_inv_default_c<int64_t, double>};
876876
fmap[DPNPFuncName::DPNP_FN_INV][eft_FLT][eft_FLT] = {
877-
eft_DBL, (void *)dpnp_inv_default_c<float, double>};
877+
eft_DBL, (void *)dpnp_inv_default_c<float, float>};
878878
fmap[DPNPFuncName::DPNP_FN_INV][eft_DBL][eft_DBL] = {
879879
eft_DBL, (void *)dpnp_inv_default_c<double, double>};
880880

881881
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_INT][eft_INT] = {
882-
eft_DBL, (void *)dpnp_inv_ext_c<int32_t, double>};
882+
get_default_floating_type<>(),
883+
(void *)dpnp_inv_ext_c<
884+
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
885+
get_default_floating_type<std::false_type>(),
886+
(void *)dpnp_inv_ext_c<
887+
int32_t, func_type_map_t::find_type<
888+
get_default_floating_type<std::false_type>()>>};
883889
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_LNG][eft_LNG] = {
884-
eft_DBL, (void *)dpnp_inv_ext_c<int64_t, double>};
890+
get_default_floating_type<>(),
891+
(void *)dpnp_inv_ext_c<
892+
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
893+
get_default_floating_type<std::false_type>(),
894+
(void *)dpnp_inv_ext_c<
895+
int64_t, func_type_map_t::find_type<
896+
get_default_floating_type<std::false_type>()>>};
885897
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_FLT][eft_FLT] = {
886-
eft_DBL, (void *)dpnp_inv_ext_c<float, double>};
898+
eft_FLT, (void *)dpnp_inv_ext_c<float, float>};
887899
fmap[DPNPFuncName::DPNP_FN_INV_EXT][eft_DBL][eft_DBL] = {
888900
eft_DBL, (void *)dpnp_inv_ext_c<double, double>};
889901

@@ -1039,9 +1051,21 @@ void func_map_init_linalg_func(func_map_t &fmap)
10391051
// eft_C128, (void*)dpnp_qr_c<std::complex<double>, std::complex<double>>};
10401052

10411053
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_INT][eft_INT] = {
1042-
eft_DBL, (void *)dpnp_qr_ext_c<int32_t, double>};
1054+
get_default_floating_type<>(),
1055+
(void *)dpnp_qr_ext_c<
1056+
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
1057+
get_default_floating_type<std::false_type>(),
1058+
(void *)dpnp_qr_ext_c<
1059+
int32_t, func_type_map_t::find_type<
1060+
get_default_floating_type<std::false_type>()>>};
10431061
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_LNG][eft_LNG] = {
1044-
eft_DBL, (void *)dpnp_qr_ext_c<int64_t, double>};
1062+
get_default_floating_type<>(),
1063+
(void *)dpnp_qr_ext_c<
1064+
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>>,
1065+
get_default_floating_type<std::false_type>(),
1066+
(void *)dpnp_qr_ext_c<
1067+
int64_t, func_type_map_t::find_type<
1068+
get_default_floating_type<std::false_type>()>>};
10451069
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_FLT][eft_FLT] = {
10461070
eft_FLT, (void *)dpnp_qr_ext_c<float, float>};
10471071
fmap[DPNPFuncName::DPNP_FN_QR_EXT][eft_DBL][eft_DBL] = {
@@ -1062,9 +1086,29 @@ void func_map_init_linalg_func(func_map_t &fmap)
10621086
std::complex<double>, double>};
10631087

10641088
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_INT][eft_INT] = {
1065-
eft_DBL, (void *)dpnp_svd_ext_c<int32_t, double, double>};
1089+
get_default_floating_type<>(),
1090+
(void *)dpnp_svd_ext_c<
1091+
int32_t, func_type_map_t::find_type<get_default_floating_type<>()>,
1092+
func_type_map_t::find_type<get_default_floating_type<>()>>,
1093+
get_default_floating_type<std::false_type>(),
1094+
(void *)
1095+
dpnp_svd_ext_c<int32_t,
1096+
func_type_map_t::find_type<
1097+
get_default_floating_type<std::false_type>()>,
1098+
func_type_map_t::find_type<
1099+
get_default_floating_type<std::false_type>()>>};
10661100
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_LNG][eft_LNG] = {
1067-
eft_DBL, (void *)dpnp_svd_ext_c<int64_t, double, double>};
1101+
get_default_floating_type<>(),
1102+
(void *)dpnp_svd_ext_c<
1103+
int64_t, func_type_map_t::find_type<get_default_floating_type<>()>,
1104+
func_type_map_t::find_type<get_default_floating_type<>()>>,
1105+
get_default_floating_type<std::false_type>(),
1106+
(void *)
1107+
dpnp_svd_ext_c<int64_t,
1108+
func_type_map_t::find_type<
1109+
get_default_floating_type<std::false_type>()>,
1110+
func_type_map_t::find_type<
1111+
get_default_floating_type<std::false_type>()>>};
10681112
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_FLT][eft_FLT] = {
10691113
eft_FLT, (void *)dpnp_svd_ext_c<float, float, float>};
10701114
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_DBL][eft_DBL] = {

dpnp/backend/src/dpnp_fptr.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,17 @@ class dpnp_less_comp
260260
}
261261
};
262262

263+
/**
264+
* A template function that determines the default floating-point type
265+
* based on the value of the template parameter has_fp64.
266+
*/
267+
template <typename has_fp64 = std::true_type>
268+
static constexpr DPNPFuncType get_default_floating_type()
269+
{
270+
return has_fp64::value ? DPNPFuncType::DPNP_FT_DOUBLE
271+
: DPNPFuncType::DPNP_FT_FLOAT;
272+
}
273+
263274
/**
264275
* FPTR interface initialization functions
265276
*/

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ cdef extern from "dpnp_iface_fptr.hpp":
336336
struct DPNPFuncData:
337337
DPNPFuncType return_type
338338
void * ptr
339+
DPNPFuncType return_type_no_fp64
340+
void *ptr_no_fp64
339341

340342
DPNPFuncData get_dpnp_function_ptr(DPNPFuncName name, DPNPFuncType first_type, DPNPFuncType second_type) except +
341343

0 commit comments

Comments
 (0)