Skip to content

Commit

Permalink
Add nth_element function
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed Apr 19, 2024
1 parent d264a9d commit 81c5741
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
44 changes: 44 additions & 0 deletions src/kokkos_ext/ArborX_DetailsKokkosExtStdAlgorithms.hpp
Expand Up @@ -13,6 +13,7 @@
#define ARBORX_DETAILS_KOKKOS_EXT_STD_ALGORITHMS_HPP

#include <ArborX_DetailsKokkosExtAccessibilityTraits.hpp>
#include <ArborX_DetailsKokkosExtSwap.hpp>
#include <ArborX_Exception.hpp>

#include <Kokkos_Core.hpp>
Expand Down Expand Up @@ -149,6 +150,49 @@ void iota(ExecutionSpace const &space, ViewType const &v,
KOKKOS_LAMBDA(int i) { v(i) = value + (ValueType)i; });
}

template <typename Iterator>
KOKKOS_FUNCTION void nth_element(Iterator first, Iterator nth, Iterator last)
{
if (first == last || nth == last)
return;

// Lomuto partitioning
auto partition = [](Iterator left, Iterator right, Iterator pivot) {
--right;

KokkosExt::swap(*pivot, *right);
auto it_i = left;
auto it_j = left;
while (it_j < right)
{
if (*it_j < *right)
KokkosExt::swap(*it_j, *(it_i++));
++it_j;
}
KokkosExt::swap(*it_i, *right);
return it_i;
};

// Simple quickselect implementation
while (true)
{
if (first == last)
return;

// Choosing nth element as a pivot should lead to early exit if the array is
// sorted
auto pivot = partition(first, last, nth);

if (pivot == nth)
return;

if (nth < pivot)
last = pivot;
else
first = pivot + 1;
}
}

} // namespace ArborX::Details::KokkosExt

#endif
40 changes: 40 additions & 0 deletions test/tstDetailsKokkosExtStdAlgorithms.cpp
Expand Up @@ -12,6 +12,7 @@
#include "ArborX_EnableDeviceTypes.hpp" // ARBORX_DEVICE_TYPES
#include "ArborX_EnableViewComparison.hpp"
#include <ArborX_DetailsKokkosExtStdAlgorithms.hpp>
#include <ArborX_DetailsKokkosExtViewHelpers.hpp>
#include <ArborX_Exception.hpp>

#include <Kokkos_Core.hpp>
Expand Down Expand Up @@ -167,3 +168,42 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(adjacent_difference, DeviceType,
Kokkos::resize(x, 5);
BOOST_CHECK_THROW(adjacent_difference(space, y, x), ArborX::SearchException);
}

BOOST_AUTO_TEST_CASE_TEMPLATE(nth_element, DeviceType, ARBORX_DEVICE_TYPES)
{
using ExecutionSpace = typename DeviceType::execution_space;
ExecutionSpace space;

using ArborX::Details::KokkosExt::nth_element;

for (auto v_ref : {std::vector<float>{}, std::vector<float>{0.5f},
std::vector<float>{0.1f, 0.1f, 0.1f},
std::vector<float>{0.1f, 0.2f, 0.3f},
std::vector<float>{0.1f, 0.3f, -0.5f, 1.0f, -0.9f, -1.2f}})
{
int const n = v_ref.size();

Kokkos::View<float *, DeviceType> v("v", n);
Kokkos::deep_copy(
space, v,
Kokkos::View<float *, Kokkos::HostSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>(v_ref.data(), n));

Kokkos::View<float *, DeviceType> nth("nth", n);
for (int i = 0; i < n; ++i)
{
auto v_copy = ArborX::Details::KokkosExt::clone(space, v);
Kokkos::parallel_for(
Kokkos::RangePolicy<ExecutionSpace>(space, 0, 1), KOKKOS_LAMBDA(int) {
nth_element(v_copy.data(), v_copy.data() + i, v_copy.data() + n);
nth(i) = v_copy(i);
});
}
space.fence();

auto nth_host =
Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, nth);
std::sort(v_ref.begin(), v_ref.end());
BOOST_TEST(nth_host == v_ref, tt::per_element());
}
}

0 comments on commit 81c5741

Please sign in to comment.