diff --git a/src/kokkos_ext/ArborX_DetailsKokkosExtStdAlgorithms.hpp b/src/kokkos_ext/ArborX_DetailsKokkosExtStdAlgorithms.hpp index e3da1c8ea..d24065279 100644 --- a/src/kokkos_ext/ArborX_DetailsKokkosExtStdAlgorithms.hpp +++ b/src/kokkos_ext/ArborX_DetailsKokkosExtStdAlgorithms.hpp @@ -13,6 +13,7 @@ #define ARBORX_DETAILS_KOKKOS_EXT_STD_ALGORITHMS_HPP #include +#include #include #include @@ -149,6 +150,49 @@ void iota(ExecutionSpace const &space, ViewType const &v, KOKKOS_LAMBDA(int i) { v(i) = value + (ValueType)i; }); } +template +KOKKOS_FUNCTION void nth_element(Iterator first, Iterator nth, Iterator last) +{ + if (first == last || nth == last) + return; + + // Lomuto partitioning + auto partition = [](Iterator left, Iterator right, Iterator pivot) { + --right; + + KokkosExt::swap(*pivot, *right); + auto it_i = left; + auto it_j = left; + while (it_j < right) + { + if (*it_j < *right) + KokkosExt::swap(*it_j, *(it_i++)); + ++it_j; + } + KokkosExt::swap(*it_i, *right); + return it_i; + }; + + // Simple quickselect implementation + while (true) + { + if (first == last) + return; + + // Choosing nth element as a pivot should lead to early exit if the array is + // sorted + auto pivot = partition(first, last, nth); + + if (pivot == nth) + return; + + if (nth < pivot) + last = pivot; + else + first = pivot + 1; + } +} + } // namespace ArborX::Details::KokkosExt #endif diff --git a/test/tstDetailsKokkosExtStdAlgorithms.cpp b/test/tstDetailsKokkosExtStdAlgorithms.cpp index 62c498413..f584de109 100644 --- a/test/tstDetailsKokkosExtStdAlgorithms.cpp +++ b/test/tstDetailsKokkosExtStdAlgorithms.cpp @@ -12,6 +12,7 @@ #include "ArborX_EnableDeviceTypes.hpp" // ARBORX_DEVICE_TYPES #include "ArborX_EnableViewComparison.hpp" #include +#include #include #include @@ -167,3 +168,42 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(adjacent_difference, DeviceType, Kokkos::resize(x, 5); BOOST_CHECK_THROW(adjacent_difference(space, y, x), ArborX::SearchException); } + +BOOST_AUTO_TEST_CASE_TEMPLATE(nth_element, DeviceType, ARBORX_DEVICE_TYPES) +{ + using ExecutionSpace = typename DeviceType::execution_space; + ExecutionSpace space; + + using ArborX::Details::KokkosExt::nth_element; + + for (auto v_ref : {std::vector{}, std::vector{0.5f}, + std::vector{0.1f, 0.1f, 0.1f}, + std::vector{0.1f, 0.2f, 0.3f}, + std::vector{0.1f, 0.3f, -0.5f, 1.0f, -0.9f, -1.2f}}) + { + int const n = v_ref.size(); + + Kokkos::View v("v", n); + Kokkos::deep_copy( + space, v, + Kokkos::View>(v_ref.data(), n)); + + Kokkos::View nth("nth", n); + for (int i = 0; i < n; ++i) + { + auto v_copy = ArborX::Details::KokkosExt::clone(space, v); + Kokkos::parallel_for( + Kokkos::RangePolicy(space, 0, 1), KOKKOS_LAMBDA(int) { + nth_element(v_copy.data(), v_copy.data() + i, v_copy.data() + n); + nth(i) = v_copy(i); + }); + } + space.fence(); + + auto nth_host = + Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, nth); + std::sort(v_ref.begin(), v_ref.end()); + BOOST_TEST(nth_host == v_ref, tt::per_element()); + } +}