Skip to content

Commit 8964e2d

Browse files
authored
[DFT] Support AdaptiveCpp in cuFFT and rocFFT backends (#665)
1 parent 657f502 commit 8964e2d

File tree

10 files changed

+66
-37
lines changed

10 files changed

+66
-37
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ Supported compilers include:
338338
<tr>
339339
<td rowspan=2 align="center">NVIDIA GPU</td>
340340
<td align="center">NVIDIA cuFFT</td>
341-
<td align="center">Open DPC++</td>
341+
<td align="center">Open DPC++<br>AdaptiveCpp</td>
342342
<td align="center">Dynamic, Static</td>
343343
</tr>
344344
<tr>
@@ -349,7 +349,7 @@ Supported compilers include:
349349
<tr>
350350
<td rowspan=2 align="center">AMD GPU</td>
351351
<td align="center">AMD rocFFT</td>
352-
<td align="center">Open DPC++</td>
352+
<td align="center">Open DPC++<br>AdaptiveCpp</td>
353353
<td align="center">Dynamic, Static</td>
354354
</tr>
355355
<tr>

docs/building_the_project_with_adaptivecpp.rst

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ additional guidance. The target architectures must be specified with
5757
``HIP_TARGETS``. See the `AdaptiveCpp documentation
5858
<https://github.com/AdaptiveCpp/AdaptiveCpp/blob/develop/doc/using-hipsycl.md#adaptivecpp-targets-specification>`_.
5959

60-
If a backend library supports multiple domains (i.e. BLAS, RNG), it may be
60+
If a backend library supports multiple domains (e.g. BLAS, DFT, RNG), it may be
6161
desirable to only enable selected domains. For this, the ``TARGET_DOMAINS``
6262
variable should be set. For further details, see :ref:`build_target_domains`.
6363

@@ -81,6 +81,9 @@ The most important supported build options are:
8181
* - ENABLE_CUBLAS_BACKEND
8282
- True, False
8383
- False
84+
* - ENABLE_CUFFT_BACKEND
85+
- True, False
86+
- False
8487
* - ENABLE_CURAND_BACKEND
8588
- True, False
8689
- False
@@ -93,6 +96,9 @@ The most important supported build options are:
9396
* - ENABLE_ROCBLAS_BACKEND
9497
- True, False
9598
- False
99+
* - ENABLE_ROCFFT_BACKEND
100+
- True, False
101+
- False
96102
* - ENABLE_ROCRAND_BACKEND
97103
- True, False
98104
- False
@@ -106,7 +112,7 @@ The most important supported build options are:
106112
- True, False
107113
- True
108114
* - TARGET_DOMAINS (list)
109-
- blas, rng
115+
- blas, dft, rng
110116
- All supported domains
111117

112118
Some additional build options are given in
@@ -120,8 +126,8 @@ Backends
120126
Building for CUDA
121127
~~~~~~~~~~~~~~~~~
122128

123-
The CUDA backends can be enabled with ``ENABLE_CUBLAS_BACKEND`` and
124-
``ENABLE_CURAND_BACKEND``.
129+
The CUDA backends can be enabled with ``ENABLE_CUBLAS_BACKEND``,
130+
``ENABLE_CUFFT_BACKEND`` and ``ENABLE_CURAND_BACKEND``.
125131

126132
The target architecture must be set using the ``HIPSYCL_TARGETS`` parameter. For
127133
example, to target an NVIDIA A100 (Ampere architecture), set
@@ -140,8 +146,8 @@ the CUDA libraries should be found automatically by CMake.
140146
Building for ROCm
141147
~~~~~~~~~~~~~~~~~
142148

143-
The ROCm backends can be enabled with ``ENABLE_ROCBLAS_BACKEND`` and
144-
``ENABLE_ROCRAND_BACKEND``.
149+
The ROCm backends can be enabled with ``ENABLE_ROCBLAS_BACKEND``,
150+
``ENABLE_ROCFFT_BACKEND`` and ``ENABLE_ROCRAND_BACKEND``.
145151

146152
The target architecture must be set using the ``HIPSYCL_TARGETS`` parameter. See
147153
the `AdaptiveCpp documentation

src/dft/backends/cufft/backward.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ ONEMATH_EXPORT void compute_backward(descriptor_type& desc,
7676
auto stream = detail::setup_stream(func_name, ih, plan);
7777

7878
auto inout_native = reinterpret_cast<fwd<descriptor_type>*>(
79-
ih.get_native_mem<sycl::backend::ext_oneapi_cuda>(inout_acc));
79+
ih.get_native_mem<detail::sycl_cuda_backend>(inout_acc));
8080
detail::cufft_execute<detail::Direction::Backward, fwd<descriptor_type>>(
8181
func_name, stream, plan, reinterpret_cast<void*>(inout_native + offsets[0]),
8282
reinterpret_cast<void*>(inout_native + offsets[1]));
@@ -121,14 +121,14 @@ ONEMATH_EXPORT void compute_backward(descriptor_type& desc,
121121
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
122122
auto stream = detail::setup_stream(func_name, ih, plan);
123123

124-
auto in_native = reinterpret_cast<void*>(
125-
reinterpret_cast<bwd<descriptor_type>*>(
126-
ih.get_native_mem<sycl::backend::ext_oneapi_cuda>(in_acc)) +
127-
offsets[0]);
128-
auto out_native = reinterpret_cast<void*>(
129-
reinterpret_cast<fwd<descriptor_type>*>(
130-
ih.get_native_mem<sycl::backend::ext_oneapi_cuda>(out_acc)) +
131-
offsets[1]);
124+
auto in_native =
125+
reinterpret_cast<void*>(reinterpret_cast<bwd<descriptor_type>*>(
126+
ih.get_native_mem<detail::sycl_cuda_backend>(in_acc)) +
127+
offsets[0]);
128+
auto out_native =
129+
reinterpret_cast<void*>(reinterpret_cast<fwd<descriptor_type>*>(
130+
ih.get_native_mem<detail::sycl_cuda_backend>(out_acc)) +
131+
offsets[1]);
132132
detail::cufft_execute<detail::Direction::Backward, fwd<descriptor_type>>(
133133
func_name, stream, plan, in_native, out_native);
134134
});

src/dft/backends/cufft/commit.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
#include "oneapi/math/dft/detail/cufft/onemath_dft_cufft.hpp"
3535
#include "oneapi/math/dft/types.hpp"
3636

37+
#include "execute_helper.hpp"
38+
#include "../../execute_helper_generic.hpp"
3739
#include "../stride_helper.hpp"
3840

3941
#include <cufft.h>
@@ -84,7 +86,7 @@ class cufft_commit final : public dft::detail::commit_impl<prec, dom> {
8486
if (fix_context) {
8587
// cufftDestroy changes the context so change it back.
8688
CUdevice interopDevice =
87-
sycl::get_native<sycl::backend::ext_oneapi_cuda>(this->get_queue().get_device());
89+
sycl::get_native<sycl_cuda_backend>(this->get_queue().get_device());
8890
CUcontext interopContext;
8991
if (cuDevicePrimaryCtxRetain(&interopContext, interopDevice) != CUDA_SUCCESS) {
9092
throw math::exception("dft/backends/cufft", __FUNCTION__,
@@ -353,16 +355,16 @@ class cufft_commit final : public dft::detail::commit_impl<prec, dom> {
353355
.submit([&](sycl::handler& cgh) {
354356
auto workspace_acc =
355357
buffer_workspace.template get_access<sycl::access::mode::read_write>(cgh);
356-
cgh.host_task([=](sycl::interop_handle ih) {
357-
auto stream = ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
358+
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
359+
auto stream = ih.get_native_queue<sycl_cuda_backend>();
358360
auto result = cufftSetStream(plan, stream);
359361
if (result != CUFFT_SUCCESS) {
360362
throw oneapi::math::exception(
361363
"dft/backends/cufft", "set_workspace",
362364
"cufftSetStream returned " + std::to_string(result));
363365
}
364366
auto workspace_native = reinterpret_cast<scalar_type*>(
365-
ih.get_native_mem<sycl::backend::ext_oneapi_cuda>(workspace_acc));
367+
ih.get_native_mem<sycl_cuda_backend>(workspace_acc));
366368
cufftSetWorkArea(plan, workspace_native);
367369
});
368370
})

src/dft/backends/cufft/execute_helper.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@
3636

3737
namespace oneapi::math::dft::cufft::detail {
3838

39+
#if defined(__ADAPTIVECPP__) || defined(__HIPSYCL__)
40+
constexpr auto sycl_cuda_backend{ sycl::backend::cuda };
41+
#else // DPC++
42+
constexpr auto sycl_cuda_backend{ sycl::backend::ext_oneapi_cuda };
43+
#endif
44+
3945
template <dft::precision prec, dft::domain dom>
4046
inline dft::detail::commit_impl<prec, dom>* checked_get_commit(
4147
dft::detail::descriptor<prec, dom>& desc) {
@@ -142,7 +148,7 @@ void cufft_execute(const std::string& func, CUstream stream, cufftHandle plan, v
142148
}
143149

144150
inline CUstream setup_stream(const std::string& func, sycl::interop_handle ih, cufftHandle plan) {
145-
auto stream = ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
151+
auto stream = ih.get_native_queue<sycl_cuda_backend>();
146152
auto result = cufftSetStream(plan, stream);
147153
if (result != CUFFT_SUCCESS) {
148154
throw oneapi::math::exception("dft/backends/cufft", func,

src/dft/backends/cufft/forward.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ ONEMATH_EXPORT void compute_forward(descriptor_type& desc,
7979
auto stream = detail::setup_stream(func_name, ih, plan);
8080

8181
auto inout_native = reinterpret_cast<fwd<descriptor_type>*>(
82-
ih.get_native_mem<sycl::backend::ext_oneapi_cuda>(inout_acc));
82+
ih.get_native_mem<detail::sycl_cuda_backend>(inout_acc));
8383
detail::cufft_execute<detail::Direction::Forward, fwd<descriptor_type>>(
8484
func_name, stream, plan, reinterpret_cast<void*>(inout_native + offsets[0]),
8585
reinterpret_cast<void*>(inout_native + offsets[1]));
@@ -124,14 +124,14 @@ ONEMATH_EXPORT void compute_forward(descriptor_type& desc,
124124
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
125125
auto stream = detail::setup_stream(func_name, ih, plan);
126126

127-
auto in_native = reinterpret_cast<void*>(
128-
reinterpret_cast<fwd<descriptor_type>*>(
129-
ih.get_native_mem<sycl::backend::ext_oneapi_cuda>(in_acc)) +
130-
offsets[0]);
131-
auto out_native = reinterpret_cast<void*>(
132-
reinterpret_cast<bwd<descriptor_type>*>(
133-
ih.get_native_mem<sycl::backend::ext_oneapi_cuda>(out_acc)) +
134-
offsets[1]);
127+
auto in_native =
128+
reinterpret_cast<void*>(reinterpret_cast<fwd<descriptor_type>*>(
129+
ih.get_native_mem<detail::sycl_cuda_backend>(in_acc)) +
130+
offsets[0]);
131+
auto out_native =
132+
reinterpret_cast<void*>(reinterpret_cast<bwd<descriptor_type>*>(
133+
ih.get_native_mem<detail::sycl_cuda_backend>(out_acc)) +
134+
offsets[1]);
135135
detail::cufft_execute<detail::Direction::Forward, fwd<descriptor_type>>(
136136
func_name, stream, plan, in_native, out_native);
137137
});

src/dft/backends/rocfft/commit.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
#include "oneapi/math/dft/detail/rocfft/onemath_dft_rocfft.hpp"
3535
#include "oneapi/math/dft/types.hpp"
3636

37+
#include "execute_helper.hpp"
38+
#include "../../execute_helper_generic.hpp"
3739
#include "../stride_helper.hpp"
3840

3941
#include "rocfft_handle.hpp"
@@ -557,9 +559,9 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {
557559
this->get_queue().submit([&](sycl::handler& cgh) {
558560
auto workspace_acc =
559561
buffer_workspace.template get_access<sycl::access::mode::read_write>(cgh);
560-
cgh.host_task([=](sycl::interop_handle ih) {
562+
dft::detail::fft_enqueue_task(cgh, [=](sycl::interop_handle ih) {
561563
auto workspace_native = reinterpret_cast<scalar_type*>(
562-
ih.get_native_mem<sycl::backend::ext_oneapi_hip>(workspace_acc));
564+
ih.get_native_mem<sycl_hip_backend>(workspace_acc));
563565
set_workspace_impl(handle, workspace_native, workspace_bytes, "set_workspace");
564566
});
565567
});

src/dft/backends/rocfft/execute_helper.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@
3636

3737
namespace oneapi::math::dft::rocfft::detail {
3838

39+
#if defined(__ADAPTIVECPP__) || defined(__HIPSYCL__)
40+
constexpr auto sycl_hip_backend{ sycl::backend::hip };
41+
#else // DPC++
42+
constexpr auto sycl_hip_backend{ sycl::backend::ext_oneapi_hip };
43+
#endif
44+
3945
template <dft::precision prec, dft::domain dom>
4046
inline dft::detail::commit_impl<prec, dom>* checked_get_commit(
4147
dft::detail::descriptor<prec, dom>& desc) {
@@ -60,12 +66,12 @@ inline auto expect_config(DescT& desc, const char* message) {
6066

6167
template <typename Acc>
6268
inline void* native_mem(sycl::interop_handle& ih, Acc& buf) {
63-
return ih.get_native_mem<sycl::backend::ext_oneapi_hip>(buf);
69+
return ih.get_native_mem<sycl_hip_backend>(buf);
6470
}
6571

6672
inline hipStream_t setup_stream(const std::string& func, sycl::interop_handle& ih,
6773
rocfft_execution_info info) {
68-
auto stream = ih.get_native_queue<sycl::backend::ext_oneapi_hip>();
74+
auto stream = ih.get_native_queue<sycl_hip_backend>();
6975
auto result = rocfft_execution_info_set_stream(info, stream);
7076
if (result != rocfft_status_success) {
7177
throw oneapi::math::exception(

src/dft/execute_helper_generic.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ namespace oneapi::math::dft::detail {
3939
*/
4040
template <typename HandlerT, typename FnT>
4141
static inline void fft_enqueue_task(HandlerT&& cgh, FnT&& f) {
42-
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
42+
#if defined(__ADAPTIVECPP__)
43+
cgh.AdaptiveCpp_enqueue_custom_operation([=](sycl::interop_handle ih) {
44+
#elif defined(__HIPSYCL__)
45+
cgh.hipSYCL_enqueue_custom_operation([=](sycl::interop_handle ih) {
46+
#elif defined(SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND)
4347
cgh.ext_codeplay_enqueue_native_command([=](sycl::interop_handle ih) {
4448
#else
4549
cgh.host_task([=](sycl::interop_handle ih) {

tests/unit_tests/dft/source/descriptor_tests.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,9 @@ inline void recommit_values(sycl::queue& sycl_queue) {
571571
}
572572

573573
template <oneapi::math::dft::precision precision, oneapi::math::dft::domain domain>
574-
inline void change_queue_causes_wait(sycl::queue& busy_queue) {
574+
inline void change_queue_causes_wait([[maybe_unused]] sycl::queue& busy_queue) {
575+
// Skip this test in AdaptiveCpp, which doesn't support host_task
576+
#if !defined(__ADAPTIVECPP__) && !defined(__HIPSYCL__)
575577
// create a queue with work on it, and then show that work is waited on when the descriptor
576578
// is committed to a new queue.
577579
// It's possible to have a false positive result, but a false negative should not be possible.
@@ -616,6 +618,7 @@ inline void change_queue_causes_wait(sycl::queue& busy_queue) {
616618
// busy queue task has now completed.
617619
auto after_status = e.template get_info<sycl::info::event::command_execution_status>();
618620
ASSERT_EQ(after_status, sycl::info::event_command_status::complete);
621+
#endif
619622
}
620623

621624
template <oneapi::math::dft::precision precision, oneapi::math::dft::domain domain>

0 commit comments

Comments
 (0)