diff --git a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp index 96d366e63c..8130fde96a 100644 --- a/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp @@ -55,15 +55,12 @@ template struct boolean_predicate } }; -template +template struct all_reduce_wg_contig { - void operator()(sycl::nd_item &ndit, + void operator()(sycl::nd_item<1> &ndit, outT *out, - size_t &out_idx, + const size_t &out_idx, const inpT *start, const inpT *end) const { @@ -82,15 +79,12 @@ struct all_reduce_wg_contig } }; -template +template struct any_reduce_wg_contig { - void operator()(sycl::nd_item &ndit, + void operator()(sycl::nd_item<1> &ndit, outT *out, - size_t &out_idx, + const size_t &out_idx, const inpT *start, const inpT *end) const { @@ -109,9 +103,9 @@ struct any_reduce_wg_contig } }; -template struct all_reduce_wg_strided +template struct all_reduce_wg_strided { - void operator()(sycl::nd_item &ndit, + void operator()(sycl::nd_item<1> &ndit, T *out, const size_t &out_idx, const T &local_val) const @@ -129,9 +123,9 @@ template struct all_reduce_wg_strided } }; -template struct any_reduce_wg_strided +template struct any_reduce_wg_strided { - void operator()(sycl::nd_item &ndit, + void operator()(sycl::nd_item<1> &ndit, T *out, const size_t &out_idx, const T &local_val) const @@ -215,6 +209,7 @@ struct ContigBooleanReduction outT *out_ = nullptr; GroupOp group_op_; size_t reduction_max_gid_ = 0; + size_t iter_gws_ = 1; size_t reductions_per_wi = 16; public: @@ -222,28 +217,38 @@ struct ContigBooleanReduction outT *res, GroupOp group_op, size_t reduction_size, + size_t iteration_size, size_t reduction_size_per_wi) : inp_(inp), out_(res), group_op_(group_op), - reduction_max_gid_(reduction_size), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) { } - void operator()(sycl::nd_item<2> it) const + void operator()(sycl::nd_item<1> it) const { - - size_t reduction_id = it.get_group(0); - size_t reduction_batch_id = it.get_group(1); - size_t wg_size = it.get_local_range(1); - - size_t base = reduction_id * reduction_max_gid_; - size_t start = base + reduction_batch_id * wg_size * reductions_per_wi; - size_t end = std::min((start + (reductions_per_wi * wg_size)), - base + reduction_max_gid_); + const size_t red_gws_ = it.get_global_range(0) / iter_gws_; + const size_t reduction_id = it.get_global_id(0) / red_gws_; + const size_t reduction_batch_id = get_reduction_batch_id(it); + const size_t wg_size = it.get_local_range(0); + + const size_t base = reduction_id * reduction_max_gid_; + const size_t start = + base + reduction_batch_id * wg_size * reductions_per_wi; + const size_t end = std::min((start + (reductions_per_wi * wg_size)), + base + reduction_max_gid_); // reduction and atomic operations are performed // in group_op_ group_op_(it, out_, reduction_id, inp_ + start, inp_ + end); } + +private: + size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const + { + const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; + const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups; + return reduction_batch_id; + } }; typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)( @@ -332,7 +337,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q, red_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on(init_ev); - constexpr std::uint8_t group_dim = 2; + constexpr std::uint8_t dim = 1; constexpr size_t preferred_reductions_per_wi = 4; size_t reductions_per_wi = @@ -344,15 +349,14 @@ boolean_reduction_contig_impl(sycl::queue exec_q, (reduction_nelems + reductions_per_wi * wg - 1) / (reductions_per_wi * wg); - auto gws = - sycl::range{iter_nelems, reduction_groups * wg}; - auto lws = sycl::range{1, wg}; + auto gws = sycl::range{iter_nelems * reduction_groups * wg}; + auto lws = sycl::range{wg}; cgh.parallel_for< class boolean_reduction_contig_krn>( - sycl::nd_range(gws, lws), + sycl::nd_range(gws, lws), ContigBooleanReduction( - arg_tp, res_tp, GroupOpT(), reduction_nelems, + arg_tp, res_tp, GroupOpT(), reduction_nelems, iter_nelems, reductions_per_wi)); }); } @@ -404,6 +408,7 @@ struct StridedBooleanReduction InputOutputIterIndexerT inp_out_iter_indexer_; InputRedIndexerT inp_reduced_dims_indexer_; size_t reduction_max_gid_ = 0; + size_t iter_gws_ = 1; size_t reductions_per_wi = 16; public: @@ -415,23 +420,24 @@ struct StridedBooleanReduction InputOutputIterIndexerT arg_res_iter_indexer, InputRedIndexerT arg_reduced_dims_indexer, size_t reduction_size, + size_t iteration_size, size_t reduction_size_per_wi) : inp_(inp), out_(res), reduction_op_(reduction_op), group_op_(group_op), identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), inp_reduced_dims_indexer_(arg_reduced_dims_indexer), - reduction_max_gid_(reduction_size), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) { } - void operator()(sycl::nd_item<2> it) const + void operator()(sycl::nd_item<1> it) const { - - size_t reduction_id = it.get_group(0); - size_t reduction_batch_id = it.get_group(1); - size_t reduction_lid = it.get_local_id(1); - size_t wg_size = it.get_local_range(1); + const size_t red_gws_ = it.get_global_range(0) / iter_gws_; + const size_t reduction_id = it.get_global_id(0) / red_gws_; + const size_t reduction_batch_id = get_reduction_batch_id(it); + const size_t reduction_lid = it.get_local_id(0); + const size_t wg_size = it.get_local_range(0); auto inp_out_iter_offsets_ = inp_out_iter_indexer_(reduction_id); const py::ssize_t &inp_iter_offset = @@ -442,26 +448,34 @@ struct StridedBooleanReduction outT local_red_val(identity_); size_t arg_reduce_gid0 = reduction_lid + reduction_batch_id * wg_size * reductions_per_wi; - for (size_t m = 0; m < reductions_per_wi; ++m) { - size_t arg_reduce_gid = arg_reduce_gid0 + m * wg_size; - - if (arg_reduce_gid < reduction_max_gid_) { - py::ssize_t inp_reduction_offset = static_cast( - inp_reduced_dims_indexer_(arg_reduce_gid)); - py::ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; + size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg_size); + for (size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg_size) + { + py::ssize_t inp_reduction_offset = static_cast( + inp_reduced_dims_indexer_(arg_reduce_gid)); + py::ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; - // must convert to boolean first to handle nans - using dpctl::tensor::type_utils::convert_impl; - bool val = convert_impl(inp_[inp_offset]); - ReductionOp op = reduction_op_; + // must convert to boolean first to handle nans + using dpctl::tensor::type_utils::convert_impl; + bool val = convert_impl(inp_[inp_offset]); + ReductionOp op = reduction_op_; - local_red_val = op(local_red_val, static_cast(val)); - } + local_red_val = op(local_red_val, static_cast(val)); } // reduction and atomic operations are performed // in group_op_ group_op_(it, out_, out_iter_offset, local_red_val); } + +private: + size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const + { + const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; + const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups; + return reduction_batch_id; + } }; template {iter_nelems, reduction_groups * wg}; - auto lws = sycl::range{1, wg}; + auto gws = sycl::range{iter_nelems * reduction_groups * wg}; + auto lws = sycl::range{wg}; cgh.parallel_for>( - sycl::nd_range(gws, lws), + sycl::nd_range(gws, lws), StridedBooleanReduction( arg_tp, res_tp, RedOpT(), GroupOpT(), identity_val, in_out_iter_indexer, reduction_indexer, reduction_nelems, - reductions_per_wi)); + iter_nelems, reductions_per_wi)); }); } return red_ev; diff --git a/dpctl/tensor/libtensor/include/kernels/reductions.hpp b/dpctl/tensor/libtensor/include/kernels/reductions.hpp index cab8e85540..c8aae0a3b9 100644 --- a/dpctl/tensor/libtensor/include/kernels/reductions.hpp +++ b/dpctl/tensor/libtensor/include/kernels/reductions.hpp @@ -122,6 +122,7 @@ struct ReductionOverGroupWithAtomicFunctor InputOutputIterIndexerT inp_out_iter_indexer_; InputRedIndexerT inp_reduced_dims_indexer_; size_t reduction_max_gid_ = 0; + size_t iter_gws_ = 1; size_t reductions_per_wi = 16; public: @@ -133,22 +134,23 @@ struct ReductionOverGroupWithAtomicFunctor InputOutputIterIndexerT arg_res_iter_indexer, InputRedIndexerT arg_reduced_dims_indexer, size_t reduction_size, + size_t iteration_size, size_t reduction_size_per_wi) : inp_(data), out_(res), reduction_op_(reduction_op), identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), inp_reduced_dims_indexer_(arg_reduced_dims_indexer), - reduction_max_gid_(reduction_size), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) { } - void operator()(sycl::nd_item<2> it) const + void operator()(sycl::nd_item<1> it) const { - - size_t iter_gid = it.get_global_id(0); - size_t reduction_batch_id = it.get_group(1); - size_t reduction_lid = it.get_local_id(1); - size_t wg = it.get_local_range(1); // 0 <= reduction_lid < wg + const size_t red_gws_ = it.get_global_range(0) / iter_gws_; + const size_t iter_gid = it.get_global_id(0) / red_gws_; + const size_t reduction_batch_id = get_reduction_batch_id(it); + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg // work-items sums over input with indices // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg @@ -202,6 +204,14 @@ struct ReductionOverGroupWithAtomicFunctor } } } + +private: + size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const + { + const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; + const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups; + return reduction_batch_id; + } }; typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)( @@ -222,6 +232,9 @@ typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)( template class sum_reduction_over_group_with_atomics_krn; +template +class sum_reduction_over_group_with_atomics_init_krn; + template class sum_reduction_seq_strided_krn; @@ -295,13 +308,16 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl( iter_shape_and_strides + 2 * iter_nd; IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, res_strides); - + using InitKernelName = + class sum_reduction_over_group_with_atomics_init_krn; cgh.depends_on(depends); - cgh.parallel_for(sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { - auto res_offset = res_indexer(id[0]); - res_tp[res_offset] = identity_val; - }); + cgh.parallel_for( + sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = identity_val; + }); }); sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { @@ -343,21 +359,21 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl( } auto globalRange = - sycl::range<2>{iter_nelems, reduction_groups * wg}; - auto localRange = sycl::range<2>{1, wg}; + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; using KernelName = class sum_reduction_over_group_with_atomics_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; cgh.parallel_for( - sycl::nd_range<2>(globalRange, localRange), + sycl::nd_range<1>(globalRange, localRange), ReductionOverGroupWithAtomicFunctor( arg_tp, res_tp, ReductionOpT(), identity_val, in_out_iter_indexer, reduction_indexer, reduction_nelems, - reductions_per_wi)); + iter_nelems, reductions_per_wi)); }); return comp_ev; @@ -480,21 +496,21 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl( } auto globalRange = - sycl::range<2>{iter_nelems, reduction_groups * wg}; - auto localRange = sycl::range<2>{1, wg}; + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; using KernelName = class sum_reduction_over_group_with_atomics_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; cgh.parallel_for( - sycl::nd_range<2>(globalRange, localRange), + sycl::nd_range<1>(globalRange, localRange), ReductionOverGroupWithAtomicFunctor( arg_tp, res_tp, ReductionOpT(), identity_val, in_out_iter_indexer, reduction_indexer, reduction_nelems, - reductions_per_wi)); + iter_nelems, reductions_per_wi)); }); return comp_ev; @@ -518,6 +534,7 @@ struct ReductionOverGroupNoAtomicFunctor InputOutputIterIndexerT inp_out_iter_indexer_; InputRedIndexerT inp_reduced_dims_indexer_; size_t reduction_max_gid_ = 0; + size_t iter_gws_ = 1; size_t reductions_per_wi = 16; public: @@ -529,22 +546,25 @@ struct ReductionOverGroupNoAtomicFunctor InputOutputIterIndexerT arg_res_iter_indexer, InputRedIndexerT arg_reduced_dims_indexer, size_t reduction_size, + size_t iteration_size, size_t reduction_size_per_wi) : inp_(data), out_(res), reduction_op_(reduction_op), identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), inp_reduced_dims_indexer_(arg_reduced_dims_indexer), - reduction_max_gid_(reduction_size), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) { } - void operator()(sycl::nd_item<2> it) const + void operator()(sycl::nd_item<1> it) const { - size_t iter_gid = it.get_global_id(0); - size_t reduction_batch_id = it.get_group(1); - size_t reduction_lid = it.get_local_id(1); - size_t wg = it.get_local_range(1); // 0 <= reduction_lid < wg + const size_t red_gws_ = it.get_global_range(0) / iter_gws_; + const size_t iter_gid = it.get_global_id(0) / red_gws_; + const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_; + const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups; + const size_t reduction_lid = it.get_local_id(0); + const size_t wg = it.get_local_range(0); // 0 <= reduction_lid < wg // work-items sums over input with indices // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg @@ -580,7 +600,7 @@ struct ReductionOverGroupNoAtomicFunctor if (work_group.leader()) { // each group writes to a different memory location - out_[out_iter_offset * it.get_group_range(1) + reduction_batch_id] = + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = red_val_over_wg; } } @@ -647,20 +667,20 @@ sycl::event sum_reduction_over_group_temps_strided_impl( assert(reduction_groups == 1); auto globalRange = - sycl::range<2>{iter_nelems, reduction_groups * wg}; - auto localRange = sycl::range<2>{1, wg}; + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; using KernelName = class sum_reduction_over_group_temps_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; cgh.parallel_for( - sycl::nd_range<2>(globalRange, localRange), + sycl::nd_range<1>(globalRange, localRange), ReductionOverGroupNoAtomicFunctor( arg_tp, res_tp, ReductionOpT(), identity_val, in_out_iter_indexer, reduction_indexer, reduction_nelems, - reductions_per_wi)); + iter_nelems, reductions_per_wi)); }); return comp_ev; @@ -713,20 +733,20 @@ sycl::event sum_reduction_over_group_temps_strided_impl( reduction_shape_stride}; auto globalRange = - sycl::range<2>{iter_nelems, reduction_groups * wg}; - auto localRange = sycl::range<2>{1, wg}; + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; using KernelName = class sum_reduction_over_group_temps_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; cgh.parallel_for( - sycl::nd_range<2>(globalRange, localRange), + sycl::nd_range<1>(globalRange, localRange), ReductionOverGroupNoAtomicFunctor( arg_tp, partially_reduced_tmp, ReductionOpT(), identity_val, in_out_iter_indexer, reduction_indexer, reduction_nelems, - preferrered_reductions_per_wi)); + iter_nelems, preferrered_reductions_per_wi)); }); size_t remaining_reduction_nelems = reduction_groups; @@ -768,20 +788,20 @@ sycl::event sum_reduction_over_group_temps_strided_impl( ReductionIndexerT reduction_indexer{}; auto globalRange = - sycl::range<2>{iter_nelems, reduction_groups_ * wg}; - auto localRange = sycl::range<2>{1, wg}; + sycl::range<1>{iter_nelems * reduction_groups_ * wg}; + auto localRange = sycl::range<1>{wg}; using KernelName = class sum_reduction_over_group_temps_krn< resTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; cgh.parallel_for( - sycl::nd_range<2>(globalRange, localRange), + sycl::nd_range<1>(globalRange, localRange), ReductionOverGroupNoAtomicFunctor< resTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>( temp_arg, temp2_arg, ReductionOpT(), identity_val, in_out_iter_indexer, reduction_indexer, - remaining_reduction_nelems, + remaining_reduction_nelems, iter_nelems, preferrered_reductions_per_wi)); }); @@ -824,20 +844,21 @@ sycl::event sum_reduction_over_group_temps_strided_impl( assert(reduction_groups == 1); auto globalRange = - sycl::range<2>{iter_nelems, reduction_groups * wg}; - auto localRange = sycl::range<2>{1, wg}; + sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; using KernelName = class sum_reduction_over_group_temps_krn< argTy, resTy, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT>; cgh.parallel_for( - sycl::nd_range<2>(globalRange, localRange), + sycl::nd_range<1>(globalRange, localRange), ReductionOverGroupNoAtomicFunctor( temp_arg, res_tp, ReductionOpT(), identity_val, in_out_iter_indexer, reduction_indexer, - remaining_reduction_nelems, reductions_per_wi)); + remaining_reduction_nelems, iter_nelems, + reductions_per_wi)); }); sycl::event cleanup_host_task_event = diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py index 76230fc655..fc2a0ec8de 100644 --- a/dpctl/tests/test_tensor_sum.py +++ b/dpctl/tests/test_tensor_sum.py @@ -156,3 +156,19 @@ def test_sum_keepdims_zero_size(): a0 = a[0] s5 = dpt.sum(a0, keepdims=True) assert s5.shape == (1, 1) + + +@pytest.mark.parametrize("arg_dtype", ["i8", "f4", "c8"]) +@pytest.mark.parametrize("n", [1023, 1024, 1025]) +def test_largish_reduction(arg_dtype, n): + q = get_queue_or_skip() + skip_if_dtype_not_supported(arg_dtype, q) + + m = 5 + x = dpt.ones((m, n, m), dtype=arg_dtype) + + y1 = dpt.sum(x, axis=(0, 1)) + y2 = dpt.sum(x, axis=(1, 2)) + + assert dpt.all(dpt.equal(y1, y2)) + assert dpt.all(dpt.equal(y1, n * m))