Skip to content

Commit 0959c79

Browse files
Update dpnp.linalg.svd
1 parent f35bba6 commit 0959c79

File tree

4 files changed

+45
-22
lines changed

4 files changed

+45
-22
lines changed

dpnp/backend/kernels/dpnp_krnl_linalg.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,9 +1099,29 @@ void func_map_init_linalg_func(func_map_t &fmap)
10991099
std::complex<double>, double>};
11001100

11011101
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_INT][eft_INT] = {
1102-
eft_DBL, (void *)dpnp_svd_ext_c<int32_t, double, double>};
1102+
get_res_type_with_aspect<>(),
1103+
(void *)dpnp_svd_ext_c<
1104+
int32_t, func_type_map_t::find_type<get_res_type_with_aspect<>()>,
1105+
func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
1106+
get_res_type_with_aspect<std::false_type>(),
1107+
(void *)
1108+
dpnp_svd_ext_c<int32_t,
1109+
func_type_map_t::find_type<
1110+
get_res_type_with_aspect<std::false_type>()>,
1111+
func_type_map_t::find_type<
1112+
get_res_type_with_aspect<std::false_type>()>>};
11031113
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_LNG][eft_LNG] = {
1104-
eft_DBL, (void *)dpnp_svd_ext_c<int64_t, double, double>};
1114+
get_res_type_with_aspect<>(),
1115+
(void *)dpnp_svd_ext_c<
1116+
int64_t, func_type_map_t::find_type<get_res_type_with_aspect<>()>,
1117+
func_type_map_t::find_type<get_res_type_with_aspect<>()>>,
1118+
get_res_type_with_aspect<std::false_type>(),
1119+
(void *)
1120+
dpnp_svd_ext_c<int64_t,
1121+
func_type_map_t::find_type<
1122+
get_res_type_with_aspect<std::false_type>()>,
1123+
func_type_map_t::find_type<
1124+
get_res_type_with_aspect<std::false_type>()>>};
11051125
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_FLT][eft_FLT] = {
11061126
eft_FLT, (void *)dpnp_svd_ext_c<float, float, float>};
11071127
fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_DBL][eft_DBL] = {

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -533,26 +533,31 @@ cpdef tuple dpnp_svd(utils.dpnp_descriptor x1, cpp_bool full_matrices, cpp_bool
533533
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
534534
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_SVD_EXT, param1_type, param1_type)
535535

536-
cdef DPNPFuncType type_s = DPNP_FT_DOUBLE
537-
if x1.dtype == dpnp.float32:
538-
type_s = DPNP_FT_FLOAT
539-
540536
x1_obj = x1.get_array()
541537

538+
cdef custom_linalg_1in_3out_shape_t func = NULL
539+
cdef DPNPFuncType return_type = DPNP_FT_NONE
540+
if dpnp.issubdtype(x1_obj.dtype, dpnp.integer) and not x1_obj.sycl_device.has_aspect_fp64:
541+
return_type = kernel_data.return_type_no_fp64
542+
func = < custom_linalg_1in_3out_shape_t > kernel_data.ptr_no_fp64
543+
else:
544+
return_type = kernel_data.return_type
545+
func = < custom_linalg_1in_3out_shape_t > kernel_data.ptr
546+
542547
cdef utils.dpnp_descriptor res_u = utils.create_output_descriptor((size_m, size_m),
543-
kernel_data.return_type,
548+
return_type,
544549
None,
545550
device=x1_obj.sycl_device,
546551
usm_type=x1_obj.usm_type,
547552
sycl_queue=x1_obj.sycl_queue)
548553
cdef utils.dpnp_descriptor res_s = utils.create_output_descriptor((size_s, ),
549-
type_s,
554+
return_type,
550555
None,
551556
device=x1_obj.sycl_device,
552557
usm_type=x1_obj.usm_type,
553558
sycl_queue=x1_obj.sycl_queue)
554559
cdef utils.dpnp_descriptor res_vt = utils.create_output_descriptor((size_n, size_n),
555-
kernel_data.return_type,
560+
return_type,
556561
None,
557562
device=x1_obj.sycl_device,
558563
usm_type=x1_obj.usm_type,
@@ -563,8 +568,6 @@ cpdef tuple dpnp_svd(utils.dpnp_descriptor x1, cpp_bool full_matrices, cpp_bool
563568
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
564569
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
565570

566-
cdef custom_linalg_1in_3out_shape_t func = < custom_linalg_1in_3out_shape_t > kernel_data.ptr
567-
568571
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
569572
x1.get_data(),
570573
res_u.get_data(),

tests/skipped_tests.tbl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,6 @@ tests/test_linalg.py::test_norm1[None-3-[7]]
112112
tests/test_linalg.py::test_norm1[None-3-[1, 2]]
113113
tests/test_linalg.py::test_norm1[None-3-[1, 0]]
114114

115-
tests/test_linalg.py::test_svd[(2,2)-float64]
116-
tests/test_linalg.py::test_svd[(3,4)-float64]
117-
tests/test_linalg.py::test_svd[(5,3)-float64]
118-
tests/test_linalg.py::test_svd[(16,16)-float64]
119-
120115
tests/test_logic.py::test_allclose[int32]
121116
tests/test_logic.py::test_allclose[int64]
122117
tests/test_logic.py::test_allclose[float32]

tests/test_linalg.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -396,17 +396,22 @@ def test_svd(type, shape):
396396
np_u, np_s, np_vt = numpy.linalg.svd(a)
397397
dpnp_u, dpnp_s, dpnp_vt = inp.linalg.svd(ia)
398398

399-
assert dpnp_u.dtype == np_u.dtype
400-
assert dpnp_s.dtype == np_s.dtype
401-
assert dpnp_vt.dtype == np_vt.dtype
399+
support_aspect64 = has_support_aspect64()
400+
401+
if support_aspect64:
402+
assert dpnp_u.dtype == np_u.dtype
403+
assert dpnp_s.dtype == np_s.dtype
404+
assert dpnp_vt.dtype == np_vt.dtype
402405
assert dpnp_u.shape == np_u.shape
403406
assert dpnp_s.shape == np_s.shape
404407
assert dpnp_vt.shape == np_vt.shape
405408

406-
if type == numpy.float32:
409+
tol = 1e-12
410+
if support_aspect64:
411+
if type == numpy.float32:
412+
tol = 1e-03
413+
elif type in (numpy.float32, numpy.int32, numpy.int64, None):
407414
tol = 1e-03
408-
else:
409-
tol = 1e-12
410415

411416
# check decomposition
412417
dpnp_diag_s = inp.zeros(shape, dtype=dpnp_s.dtype)

0 commit comments

Comments
 (0)