|
27 | 27 | #include <mkl_version.h> |
28 | 28 | #if INTEL_MKL_VERSION < 20250000 |
29 | 29 | #include <mkl/dfti.hpp> |
| 30 | +namespace oneapi::math::dft::mklgpu::detail { |
| 31 | +constexpr int committed = DFTI_COMMITTED; |
| 32 | +constexpr int uncommitted = DFTI_UNCOMMITTED; |
| 33 | +} // namespace oneapi::math::dft::mklgpu::detail |
30 | 34 | #else |
31 | 35 | #include <mkl/dft.hpp> |
| 36 | +namespace oneapi::math::dft::mklgpu::detail { |
| 37 | +constexpr auto committed = oneapi::mkl::dft::config_value::COMMITTED; |
| 38 | +constexpr auto uncommitted = oneapi::mkl::dft::config_value::UNCOMMITTED; |
| 39 | +} // namespace oneapi::math::dft::mklgpu::detail |
32 | 40 | #endif |
33 | 41 |
|
34 | 42 | namespace oneapi { |
@@ -57,120 +65,61 @@ inline constexpr oneapi::mkl::dft::precision to_mklgpu(dft::detail::precision do |
57 | 65 | } |
58 | 66 | } |
59 | 67 |
|
60 | | -/// Convert a config_param to equivalent backend native value. |
61 | | -inline constexpr oneapi::mkl::dft::config_param to_mklgpu(dft::detail::config_param param) { |
62 | | - using iparam = dft::detail::config_param; |
63 | | - using oparam = oneapi::mkl::dft::config_param; |
64 | | - switch (param) { |
65 | | - case iparam::FORWARD_DOMAIN: return oparam::FORWARD_DOMAIN; |
66 | | - case iparam::DIMENSION: return oparam::DIMENSION; |
67 | | - case iparam::LENGTHS: return oparam::LENGTHS; |
68 | | - case iparam::PRECISION: return oparam::PRECISION; |
69 | | - case iparam::FORWARD_SCALE: return oparam::FORWARD_SCALE; |
70 | | - case iparam::NUMBER_OF_TRANSFORMS: return oparam::NUMBER_OF_TRANSFORMS; |
71 | | - case iparam::COMPLEX_STORAGE: return oparam::COMPLEX_STORAGE; |
72 | | - case iparam::CONJUGATE_EVEN_STORAGE: return oparam::CONJUGATE_EVEN_STORAGE; |
73 | | - case iparam::FWD_DISTANCE: return oparam::FWD_DISTANCE; |
74 | | - case iparam::BWD_DISTANCE: return oparam::BWD_DISTANCE; |
75 | | - case iparam::WORKSPACE: return oparam::WORKSPACE; |
76 | | - case iparam::PACKED_FORMAT: return oparam::PACKED_FORMAT; |
77 | | - case iparam::WORKSPACE_PLACEMENT: return oparam::WORKSPACE; // Same as WORKSPACE |
78 | | - case iparam::WORKSPACE_EXTERNAL_BYTES: return oparam::WORKSPACE_BYTES; |
79 | | - case iparam::COMMIT_STATUS: return oparam::COMMIT_STATUS; |
80 | | - default: |
81 | | - throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", |
82 | | - "Invalid config param."); |
83 | | - return static_cast<oparam>(0); |
84 | | - } |
85 | | -} |
| 68 | +template <dft::detail::config_param Param> |
| 69 | +struct to_mklgpu_impl; |
86 | 70 |
|
87 | 71 | /** Convert a config_value to the backend's native value. Throw on invalid input. |
88 | 72 | * @tparam Param The config param the value is for. |
89 | 73 | * @param value The config value to convert. |
90 | 74 | **/ |
91 | 75 | template <dft::detail::config_param Param> |
92 | | -inline constexpr int to_mklgpu(dft::detail::config_value value); |
93 | | - |
94 | | -template <> |
95 | | -inline constexpr int to_mklgpu<dft::detail::config_param::COMPLEX_STORAGE>( |
96 | | - dft::detail::config_value value) { |
97 | | - if (value == dft::detail::config_value::COMPLEX_COMPLEX) { |
98 | | - return DFTI_COMPLEX_COMPLEX; |
99 | | - } |
100 | | - else { |
101 | | - throw math::unimplemented("dft", "MKLGPU descriptor set_value()", |
102 | | - "MKLGPU only supports complex-complex for complex storage."); |
103 | | - return 0; |
104 | | - } |
105 | | -} |
106 | | - |
107 | | -template <> |
108 | | -inline constexpr int to_mklgpu<dft::detail::config_param::CONJUGATE_EVEN_STORAGE>( |
109 | | - dft::detail::config_value value) { |
110 | | - if (value == dft::detail::config_value::COMPLEX_COMPLEX) { |
111 | | - return DFTI_COMPLEX_COMPLEX; |
112 | | - } |
113 | | - else { |
114 | | - throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", |
115 | | - "Invalid config value for conjugate even storage."); |
116 | | - return 0; |
117 | | - } |
| 76 | +inline constexpr auto to_mklgpu(dft::detail::config_value value) { |
| 77 | + return to_mklgpu_impl<Param>{}(value); |
118 | 78 | } |
119 | 79 |
|
| 80 | +#if INTEL_MKL_VERSION < 20250000 |
120 | 81 | template <> |
121 | | -inline constexpr int to_mklgpu<dft::detail::config_param::PLACEMENT>( |
122 | | - dft::detail::config_value value) { |
123 | | - if (value == dft::detail::config_value::INPLACE) { |
124 | | - return DFTI_INPLACE; |
125 | | - } |
126 | | - else if (value == dft::detail::config_value::NOT_INPLACE) { |
127 | | - return DFTI_NOT_INPLACE; |
128 | | - } |
129 | | - else { |
130 | | - throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", |
131 | | - "Invalid config value for inplace."); |
132 | | - return 0; |
133 | | - } |
134 | | -} |
135 | | - |
| 82 | +struct to_mklgpu_impl<dft::detail::config_param::PLACEMENT> { |
| 83 | + inline constexpr auto operator()(dft::detail::config_value value) -> int { |
| 84 | + switch (value) { |
| 85 | + case dft::detail::config_value::INPLACE: return DFTI_INPLACE; |
| 86 | + case dft::detail::config_value::NOT_INPLACE: return DFTI_NOT_INPLACE; |
| 87 | + default: |
| 88 | + throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", |
| 89 | + "Invalid config value for inplace."); |
| 90 | + } |
| 91 | + } |
| 92 | +}; |
| 93 | +#else |
136 | 94 | template <> |
137 | | -inline constexpr int to_mklgpu<dft::detail::config_param::PACKED_FORMAT>( |
138 | | - dft::detail::config_value value) { |
139 | | - if (value == dft::detail::config_value::CCE_FORMAT) { |
140 | | - return DFTI_CCE_FORMAT; |
141 | | - } |
142 | | - else { |
143 | | - throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", |
144 | | - "Invalid config value for packed format."); |
145 | | - return 0; |
146 | | - } |
147 | | -} |
148 | | - |
149 | | -/** Convert a config_value to the backend's native value. Throw on invalid input. |
150 | | - * @tparam Param The config param the value is for. |
151 | | - * @param value The config value to convert. |
152 | | -**/ |
153 | | -template <dft::detail::config_param Param> |
154 | | -inline constexpr oneapi::mkl::dft::config_value to_mklgpu_config_value( |
155 | | - dft::detail::config_value value); |
| 95 | +struct to_mklgpu_impl<dft::detail::config_param::PLACEMENT> { |
| 96 | + inline constexpr auto operator()(dft::detail::config_value value) { |
| 97 | + switch (value) { |
| 98 | + case dft::detail::config_value::INPLACE: return oneapi::mkl::dft::config_value::INPLACE; |
| 99 | + case dft::detail::config_value::NOT_INPLACE: |
| 100 | + return oneapi::mkl::dft::config_value::NOT_INPLACE; |
| 101 | + default: |
| 102 | + throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", |
| 103 | + "Invalid config value for inplace."); |
| 104 | + } |
| 105 | + } |
| 106 | +}; |
| 107 | +#endif |
156 | 108 |
|
157 | 109 | template <> |
158 | | -inline constexpr oneapi::mkl::dft::config_value |
159 | | -to_mklgpu_config_value<dft::detail::config_param::WORKSPACE_PLACEMENT>( |
160 | | - dft::detail::config_value value) { |
161 | | - if (value == dft::detail::config_value::WORKSPACE_AUTOMATIC) { |
162 | | - // NB: oneapi::mkl::dft::config_value != dft::detail::config_value |
163 | | - return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL; |
164 | | - } |
165 | | - else if (value == dft::detail::config_value::WORKSPACE_EXTERNAL) { |
166 | | - return oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL; |
167 | | - } |
168 | | - else { |
169 | | - throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", |
170 | | - "Invalid config value for workspace placement."); |
171 | | - return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL; |
172 | | - } |
173 | | -} |
| 110 | +struct to_mklgpu_impl<dft::detail::config_param::WORKSPACE_PLACEMENT> { |
| 111 | + inline constexpr auto operator()(dft::detail::config_value value) { |
| 112 | + switch (value) { |
| 113 | + case dft::detail::config_value::WORKSPACE_AUTOMATIC: |
| 114 | + return oneapi::mkl::dft::config_value::WORKSPACE_INTERNAL; |
| 115 | + case dft::detail::config_value::WORKSPACE_EXTERNAL: |
| 116 | + return oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL; |
| 117 | + default: |
| 118 | + throw math::invalid_argument("dft", "MKLGPU descriptor set_value()", |
| 119 | + "Invalid config value for workspace placement."); |
| 120 | + } |
| 121 | + } |
| 122 | +}; |
174 | 123 | } // namespace detail |
175 | 124 | } // namespace mklgpu |
176 | 125 | } // namespace dft |
|
0 commit comments