Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/generate_coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
shell: bash -l {0}

env:
python-ver: '3.11'
python-ver: '3.10'
CHANNELS: '-c dppy/label/dev -c intel -c conda-forge --override-channels'

steps:
Expand Down
45 changes: 31 additions & 14 deletions dpnp/backend/kernels/dpnp_krnl_fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,9 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
const size_t norm,
const DPCTLEventVectorRef dep_event_vec_ref)
{
static_assert(sycl::detail::is_complex<_DataType_output>::value,
"Output data type must be a complex type.");

DPCTLSyclEventRef event_ref = nullptr;

if (!shape_size || !array1_in || !result_out) {
Expand Down Expand Up @@ -476,8 +479,10 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
std::is_same<_DataType_input, int64_t>::value)
{
double *array1_copy = reinterpret_cast<double *>(
dpnp_memory_alloc_c(q_ref, input_size * sizeof(double)));
using CastType = typename _DataType_output::value_type;

CastType *array1_copy = reinterpret_cast<CastType *>(
dpnp_memory_alloc_c(q_ref, input_size * sizeof(CastType)));

shape_elem_type *copy_strides = reinterpret_cast<shape_elem_type *>(
dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
Expand All @@ -486,15 +491,17 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
*copy_shape = input_size;
shape_elem_type copy_shape_size = 1;
event_ref = dpnp_copyto_c<_DataType_input, double>(
event_ref = dpnp_copyto_c<_DataType_input, CastType>(
q_ref, array1_copy, input_size, copy_shape_size, copy_shape,
copy_strides, array1_in, input_size, copy_shape_size,
copy_shape, copy_strides, NULL, dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);

event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double,
desc_dp_real_t>(
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<
CastType, CastType,
std::conditional_t<std::is_same<CastType, double>::value,
desc_dp_real_t, desc_sp_real_t>>(
q_ref, array1_copy, result_out, input_shape, result_shape,
shape_size, input_size, result_size, inverse, norm, 0);

Expand Down Expand Up @@ -577,6 +584,8 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
const size_t norm,
const DPCTLEventVectorRef dep_event_vec_ref)
{
static_assert(sycl::detail::is_complex<_DataType_output>::value,
"Output data type must be a complex type.");
DPCTLSyclEventRef event_ref = nullptr;

if (!shape_size || !array1_in || !result_out) {
Expand Down Expand Up @@ -617,8 +626,10 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
std::is_same<_DataType_input, int64_t>::value)
{
double *array1_copy = reinterpret_cast<double *>(
dpnp_memory_alloc_c(q_ref, input_size * sizeof(double)));
using CastType = typename _DataType_output::value_type;

CastType *array1_copy = reinterpret_cast<CastType *>(
dpnp_memory_alloc_c(q_ref, input_size * sizeof(CastType)));

shape_elem_type *copy_strides = reinterpret_cast<shape_elem_type *>(
dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
Expand All @@ -627,15 +638,17 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type)));
*copy_shape = input_size;
shape_elem_type copy_shape_size = 1;
event_ref = dpnp_copyto_c<_DataType_input, double>(
event_ref = dpnp_copyto_c<_DataType_input, CastType>(
q_ref, array1_copy, input_size, copy_shape_size, copy_shape,
copy_strides, array1_in, input_size, copy_shape_size,
copy_shape, copy_strides, NULL, dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);

event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double,
desc_dp_real_t>(
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<
CastType, CastType,
std::conditional_t<std::is_same<CastType, double>::value,
desc_dp_real_t, desc_sp_real_t>>(
q_ref, array1_copy, result_out, input_shape, result_shape,
shape_size, input_size, result_size, inverse, norm, 1);

Expand Down Expand Up @@ -721,9 +734,11 @@ void func_map_init_fft_func(func_map_t &fmap)
dpnp_fft_fft_default_c<std::complex<double>, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_INT][eft_INT] = {
eft_C128, (void *)dpnp_fft_fft_ext_c<int32_t, std::complex<double>>};
eft_C128, (void *)dpnp_fft_fft_ext_c<int32_t, std::complex<double>>,
eft_C64, (void *)dpnp_fft_fft_ext_c<int32_t, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_LNG][eft_LNG] = {
eft_C128, (void *)dpnp_fft_fft_ext_c<int64_t, std::complex<double>>};
eft_C128, (void *)dpnp_fft_fft_ext_c<int64_t, std::complex<double>>,
eft_C64, (void *)dpnp_fft_fft_ext_c<int64_t, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_FLT][eft_FLT] = {
eft_C64, (void *)dpnp_fft_fft_ext_c<float, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_FFT_FFT_EXT][eft_DBL][eft_DBL] = {
Expand All @@ -748,9 +763,11 @@ void func_map_init_fft_func(func_map_t &fmap)
(void *)dpnp_fft_rfft_default_c<double, std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_INT][eft_INT] = {
eft_C128, (void *)dpnp_fft_rfft_ext_c<int32_t, std::complex<double>>};
eft_C128, (void *)dpnp_fft_rfft_ext_c<int32_t, std::complex<double>>,
eft_C64, (void *)dpnp_fft_rfft_ext_c<int32_t, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_LNG][eft_LNG] = {
eft_C128, (void *)dpnp_fft_rfft_ext_c<int64_t, std::complex<double>>};
eft_C128, (void *)dpnp_fft_rfft_ext_c<int64_t, std::complex<double>>,
eft_C64, (void *)dpnp_fft_rfft_ext_c<int64_t, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_FLT][eft_FLT] = {
eft_C64, (void *)dpnp_fft_rfft_ext_c<float, std::complex<float>>};
fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_DBL][eft_DBL] = {
Expand Down
18 changes: 14 additions & 4 deletions dpnp/fft/dpnp_algo_fft.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,15 @@ cpdef utils.dpnp_descriptor dpnp_fft(utils.dpnp_descriptor input,

input_obj = input.get_array()

# get FPTR function and return type
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
input_obj.sycl_device.has_aspect_fp64)
cdef DPNPFuncType return_type = ret_type_and_func[0]
cdef fptr_dpnp_fft_fft_t func = < fptr_dpnp_fft_fft_t > ret_type_and_func[1]

# create result array with type given by FPTR data
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(output_shape,
kernel_data.return_type,
return_type,
None,
device=input_obj.sycl_device,
usm_type=input_obj.usm_type,
Expand All @@ -81,7 +87,6 @@ cpdef utils.dpnp_descriptor dpnp_fft(utils.dpnp_descriptor input,
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()

cdef fptr_dpnp_fft_fft_t func = <fptr_dpnp_fft_fft_t > kernel_data.ptr
# call FPTR function
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
input.get_data(),
Expand Down Expand Up @@ -122,9 +127,15 @@ cpdef utils.dpnp_descriptor dpnp_rfft(utils.dpnp_descriptor input,

input_obj = input.get_array()

# get FPTR function and return type
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
input_obj.sycl_device.has_aspect_fp64)
cdef DPNPFuncType return_type = ret_type_and_func[0]
cdef fptr_dpnp_fft_fft_t func = < fptr_dpnp_fft_fft_t > ret_type_and_func[1]

# create result array with type given by FPTR data
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(output_shape,
kernel_data.return_type,
return_type,
None,
device=input_obj.sycl_device,
usm_type=input_obj.usm_type,
Expand All @@ -135,7 +146,6 @@ cpdef utils.dpnp_descriptor dpnp_rfft(utils.dpnp_descriptor input,
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()

cdef fptr_dpnp_fft_fft_t func = <fptr_dpnp_fft_fft_t > kernel_data.ptr
# call FPTR function
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
input.get_data(),
Expand Down
Loading