11// *****************************************************************************
2- // Copyright (c) 2016-2020 , Intel Corporation
2+ // Copyright (c) 2016-2023 , Intel Corporation
33// All rights reserved.
44//
55// Redistribution and use in source and binary forms, with or without
@@ -114,10 +114,10 @@ DPCTLSyclEventRef (*dpnp_around_ext_c)(DPCTLSyclQueueRef,
114114 const int ,
115115 const DPCTLEventVectorRef) = dpnp_around_c<_DataType>;
116116
117- template <typename _KernelNameSpecialization >
117+ template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2 >
118118class dpnp_elemwise_absolute_c_kernel ;
119119
120- template <typename _DataType >
120+ template <typename _DataType_input, typename _DataType_output >
121121DPCTLSyclEventRef dpnp_elemwise_absolute_c (DPCTLSyclQueueRef q_ref,
122122 const void * input1_in,
123123 void * result1,
@@ -137,43 +137,63 @@ DPCTLSyclEventRef dpnp_elemwise_absolute_c(DPCTLSyclQueueRef q_ref,
137137 sycl::queue q = *(reinterpret_cast <sycl::queue*>(q_ref));
138138 sycl::event event;
139139
140- DPNPC_ptr_adapter<_DataType> input1_ptr (q_ref, input1_in, size);
141- _DataType* array1 = input1_ptr.get_ptr ();
142- DPNPC_ptr_adapter<_DataType> result1_ptr (q_ref, result1, size, false , true );
143- _DataType* result = result1_ptr.get_ptr ();
140+ _DataType_input* array1 = static_cast <_DataType_input*>(const_cast <void *>(input1_in));
141+ _DataType_output* result = static_cast <_DataType_output*>(result1);
144142
145- if constexpr (std::is_same<_DataType, double >::value || std::is_same<_DataType, float >::value )
143+ if constexpr (is_any_v<_DataType_input, float , double , std::complex < float >, std:: complex < double >> )
146144 {
147- // https://docs.oneapi.com/versions/latest/onemkl/abs.html
148145 event = oneapi::mkl::vm::abs (q, size, array1, result);
149146 }
150147 else
151148 {
152- sycl::range<1 > gws (size);
153- auto kernel_parallel_for_func = [=](sycl::id<1 > global_id) {
154- const size_t idx = global_id[0 ];
149+ static_assert (is_any_v<_DataType_input, int32_t , int64_t >,
150+ " Integer types are only expected to pass in 'abs' kernel" );
151+ static_assert (std::is_same_v<_DataType_input, _DataType_output>, " Result type must match a type of input data" );
152+
153+ constexpr size_t lws = 64 ;
154+ constexpr unsigned int vec_sz = 8 ;
155+ constexpr sycl::access::address_space global_space = sycl::access::address_space::global_space;
156+
157+ auto gws_range = sycl::range<1 >(((size + lws * vec_sz - 1 ) / (lws * vec_sz)) * lws);
158+ auto lws_range = sycl::range<1 >(lws);
155159
156- if (array1[idx] >= 0 )
160+ auto kernel_parallel_for_func = [=](sycl::nd_item<1 > nd_it) {
161+ auto sg = nd_it.get_sub_group ();
162+ const auto max_sg_size = sg.get_max_local_range ()[0 ];
163+ const size_t start =
164+ vec_sz * (nd_it.get_group (0 ) * nd_it.get_local_range (0 ) + sg.get_group_id ()[0 ] * max_sg_size);
165+
166+ if (start + static_cast <size_t >(vec_sz) * max_sg_size < size)
157167 {
158- result[idx] = array1[idx];
168+ using input_ptrT = sycl::multi_ptr<_DataType_input, global_space>;
169+ using result_ptrT = sycl::multi_ptr<_DataType_output, global_space>;
170+
171+ sycl::vec<_DataType_input, vec_sz> data_vec = sg.load <vec_sz>(input_ptrT (&array1[start]));
172+
173+ // sycl::abs() on signed integers returns the corresponding unsigned type, so the result is explicitly cast back to the signed output type
174+ using result_absT = typename cl::sycl::detail::make_unsigned<_DataType_output>::type;
175+ sycl::vec<_DataType_output, vec_sz> res_vec =
176+ dpnp_vec_cast<_DataType_output, result_absT, vec_sz>(sycl::abs (data_vec));
177+
178+ sg.store <vec_sz>(result_ptrT (&result[start]), res_vec);
159179 }
160180 else
161181 {
162- result[idx] = -1 * array1[idx];
182+ for (size_t k = start + sg.get_local_id ()[0 ]; k < size; k += max_sg_size)
183+ {
184+ result[k] = std::abs (array1[k]);
185+ }
163186 }
164187 };
165188
166189 auto kernel_func = [&](sycl::handler& cgh) {
167- cgh.parallel_for <class dpnp_elemwise_absolute_c_kernel <_DataType>>(gws, kernel_parallel_for_func);
190+ cgh.parallel_for <class dpnp_elemwise_absolute_c_kernel <_DataType_input, _DataType_output>>(
191+ sycl::nd_range<1 >(gws_range, lws_range), kernel_parallel_for_func);
168192 };
169-
170193 event = q.submit (kernel_func);
171194 }
172195
173- input1_ptr.depends_on (event);
174- result1_ptr.depends_on (event);
175196 event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
176-
177197 return DPCTLEvent_Copy (event_ref);
178198}
179199
@@ -182,28 +202,24 @@ void dpnp_elemwise_absolute_c(const void* input1_in, void* result1, size_t size)
182202{
183203 DPCTLSyclQueueRef q_ref = reinterpret_cast <DPCTLSyclQueueRef>(&DPNP_QUEUE);
184204 DPCTLEventVectorRef dep_event_vec_ref = nullptr ;
185- DPCTLSyclEventRef event_ref = dpnp_elemwise_absolute_c<_DataType>(q_ref,
186- input1_in,
187- result1,
188- size,
189- dep_event_vec_ref);
205+ DPCTLSyclEventRef event_ref = dpnp_elemwise_absolute_c<_DataType, _DataType >(q_ref,
206+ input1_in,
207+ result1,
208+ size,
209+ dep_event_vec_ref);
190210 DPCTLEvent_WaitAndThrow (event_ref);
211+ DPCTLEvent_Delete (event_ref);
191212}
192213
193214template <typename _DataType>
194215void (*dpnp_elemwise_absolute_default_c)(const void *, void *, size_t ) = dpnp_elemwise_absolute_c<_DataType>;
195216
196- template <typename _DataType >
217+ template <typename _DataType_input, typename _DataType_output = _DataType_input >
197218DPCTLSyclEventRef (*dpnp_elemwise_absolute_ext_c)(DPCTLSyclQueueRef,
198219 const void *,
199220 void *,
200221 size_t ,
201- const DPCTLEventVectorRef) = dpnp_elemwise_absolute_c<_DataType>;
202-
203- // template void dpnp_elemwise_absolute_c<double>(void* array1_in, void* result1, size_t size);
204- // template void dpnp_elemwise_absolute_c<float>(void* array1_in, void* result1, size_t size);
205- // template void dpnp_elemwise_absolute_c<long>(void* array1_in, void* result1, size_t size);
206- // template void dpnp_elemwise_absolute_c<int>(void* array1_in, void* result1, size_t size);
222+ const DPCTLEventVectorRef) = dpnp_elemwise_absolute_c<_DataType_input, _DataType_output>;
207223
208224template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
209225DPCTLSyclEventRef dpnp_cross_c (DPCTLSyclQueueRef q_ref,
@@ -1085,10 +1101,12 @@ void func_map_init_mathematical(func_map_t& fmap)
10851101 (void *)dpnp_elemwise_absolute_ext_c<int32_t >};
10861102 fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_LNG][eft_LNG] = {eft_LNG,
10871103 (void *)dpnp_elemwise_absolute_ext_c<int64_t >};
1088- fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_FLT][eft_FLT] = {eft_FLT,
1089- (void *)dpnp_elemwise_absolute_ext_c<float >};
1090- fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_DBL][eft_DBL] = {eft_DBL,
1091- (void *)dpnp_elemwise_absolute_ext_c<double >};
1104+ fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void *)dpnp_elemwise_absolute_ext_c<float >};
1105+ fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void *)dpnp_elemwise_absolute_ext_c<double >};
1106+ fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_C64][eft_C64] = {
1107+ eft_FLT, (void *)dpnp_elemwise_absolute_ext_c<std::complex <float >, float >};
1108+ fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_C128][eft_C128] = {
1109+ eft_DBL, (void *)dpnp_elemwise_absolute_ext_c<std::complex <double >, double >};
10921110
10931111 fmap[DPNPFuncName::DPNP_FN_AROUND][eft_INT][eft_INT] = {eft_INT, (void *)dpnp_around_default_c<int32_t >};
10941112 fmap[DPNPFuncName::DPNP_FN_AROUND][eft_LNG][eft_LNG] = {eft_LNG, (void *)dpnp_around_default_c<int64_t >};
0 commit comments