-
Notifications
You must be signed in to change notification settings - Fork 23
add broadcasting for remainder func #714
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ab63fff
08b4d3d
38e1539
9bb76cb
0415b47
c85ea65
a4211b1
615f3e5
6582f7f
a99ae7d
c10d0a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -346,50 +346,80 @@ void dpnp_remainder_c(void* result_out, | |
const size_t input2_shape_ndim, | ||
const size_t* where) | ||
{ | ||
(void)input1_shape; | ||
(void)input1_shape_ndim; | ||
(void)input2_size; | ||
(void)input2_shape; | ||
(void)input2_shape_ndim; | ||
(void)where; | ||
|
||
cl::sycl::event event; | ||
_DataType_input1* input1 = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in)); | ||
_DataType_input2* input2 = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in)); | ||
if (!input1_size || !input2_size) | ||
{ | ||
return; | ||
} | ||
|
||
_DataType_input1* input1_data = reinterpret_cast<_DataType_input1*>(const_cast<void*>(input1_in)); | ||
_DataType_input2* input2_data = reinterpret_cast<_DataType_input2*>(const_cast<void*>(input2_in)); | ||
_DataType_output* result = reinterpret_cast<_DataType_output*>(result_out); | ||
|
||
if constexpr ((std::is_same<_DataType_input1, double>::value || std::is_same<_DataType_input1, float>::value) && | ||
std::is_same<_DataType_input2, _DataType_input1>::value) | ||
std::vector<size_t> result_shape = get_result_shape(input1_shape, input1_shape_ndim, | ||
input2_shape, input2_shape_ndim); | ||
|
||
DPNPC_id<_DataType_input1>* input1_it; | ||
const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input1>); | ||
input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c(input1_it_size_in_bytes)); | ||
new (input1_it) DPNPC_id<_DataType_input1>(input1_data, input1_shape, input1_shape_ndim); | ||
|
||
input1_it->broadcast_to_shape(result_shape); | ||
|
||
DPNPC_id<_DataType_input2>* input2_it; | ||
const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input2>); | ||
input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c(input2_it_size_in_bytes)); | ||
new (input2_it) DPNPC_id<_DataType_input2>(input2_data, input2_shape, input2_shape_ndim); | ||
|
||
input2_it->broadcast_to_shape(result_shape); | ||
|
||
const size_t result_size = input1_it->get_output_size(); | ||
|
||
cl::sycl::range<1> gws(result_size); | ||
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) { | ||
const size_t i = global_id[0]; | ||
const _DataType_output input1_elem = (*input1_it)[i]; | ||
const _DataType_output input2_elem = (*input2_it)[i]; | ||
double fmod_res = cl::sycl::fmod((double)input1_elem, (double)input2_elem); | ||
double add = fmod_res + input2_elem; | ||
result[i] = cl::sycl::fmod(add, (double)input2_elem); | ||
|
||
}; | ||
auto kernel_func = [&](cl::sycl::handler& cgh) { | ||
cgh.parallel_for<class dpnp_remainder_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>( | ||
gws, kernel_parallel_for_func); | ||
}; | ||
|
||
cl::sycl::event event; | ||
|
||
if (input1_size == input2_size) | ||
{ | ||
event = oneapi::mkl::vm::fmod(DPNP_QUEUE, input1_size, input1, input2, result); | ||
event.wait(); | ||
event = oneapi::mkl::vm::add(DPNP_QUEUE, input1_size, result, input2, result); | ||
event.wait(); | ||
event = oneapi::mkl::vm::fmod(DPNP_QUEUE, input1_size, result, input2, result); | ||
if constexpr ((std::is_same<_DataType_input1, double>::value || | ||
std::is_same<_DataType_input1, float>::value) && | ||
std::is_same<_DataType_input2, _DataType_input1>::value) | ||
{ | ||
event = oneapi::mkl::vm::fmod(DPNP_QUEUE, input1_size, input1_data, input2_data, result); | ||
event.wait(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use dependent event. Do not wait for each. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use dependent event. Do not wait for each. |
||
event = oneapi::mkl::vm::add(DPNP_QUEUE, input1_size, result, input2_data, result); | ||
event.wait(); | ||
event = oneapi::mkl::vm::fmod(DPNP_QUEUE, input1_size, result, input2_data, result); | ||
} | ||
else | ||
{ | ||
event = DPNP_QUEUE.submit(kernel_func); | ||
} | ||
} | ||
else | ||
{ | ||
cl::sycl::range<1> gws(input1_size); | ||
auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) { | ||
size_t i = global_id[0]; /*for (size_t i = 0; i < size; ++i)*/ | ||
{ | ||
_DataType_input1 input_elem1 = input1[i]; | ||
_DataType_input2 input_elem2 = input2[i]; | ||
double fmod = cl::sycl::fmod((double)input_elem1, (double)input_elem2); | ||
double add = fmod + input_elem2; | ||
result[i] = cl::sycl::fmod(add, (double)input_elem2); | ||
} | ||
}; | ||
|
||
auto kernel_func = [&](cl::sycl::handler& cgh) { | ||
cgh.parallel_for<class dpnp_remainder_c_kernel<_DataType_input1, _DataType_input2, _DataType_output>>( | ||
gws, kernel_parallel_for_func); | ||
}; | ||
|
||
event = DPNP_QUEUE.submit(kernel_func); | ||
} | ||
|
||
event.wait(); | ||
|
||
input1_it->~DPNPC_id(); | ||
input2_it->~DPNPC_id(); | ||
Comment on lines
+420
to
+421
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that we need to redesign this class, if explicitly call of the destructor is required for its interface. |
||
|
||
} | ||
|
||
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3> | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can call mem allocator at once here.