Skip to content

Commit

Permalink
Get rid of ranks in filterResults
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed May 3, 2024
1 parent 0bc2c62 commit 3279a65
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 32 deletions.
3 changes: 1 addition & 2 deletions src/details/ArborX_DetailsDistributedTreeNearest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,7 @@ void DistributedTreeImpl::phaseII(ExecutionSpace const &space, Tree const &tree,
distances(i) = out(i).distance;
});

DistributedTree::filterResults(space, predicates, distances, values, offset,
ranks);
DistributedTree::filterResults(space, predicates, distances, values, offset);
}

template <typename Tree, typename ExecutionSpace, typename Predicates,
Expand Down
53 changes: 23 additions & 30 deletions src/details/ArborX_DetailsDistributedTreeUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,10 @@ void forwardQueriesAndCommunicateResults(
}

template <typename ExecutionSpace, typename MemorySpace, typename Predicates,
typename Values, typename Offset, typename Ranks>
typename Values, typename Offset>
void filterResults(ExecutionSpace const &space, Predicates const &queries,
Kokkos::View<float *, MemorySpace> const &distances,
Values &values, Offset &offset, Ranks &ranks)
Values &values, Offset &offset)
{
Kokkos::Profiling::ScopedRegion guard(
"ArborX::DistributedTree::filterResults");
Expand All @@ -407,60 +407,53 @@ void filterResults(ExecutionSpace const &space, Predicates const &queries,
int const n_truncated_results = KokkosExt::lastElement(space, new_offset);
Kokkos::View<Value *, MemorySpace> new_values(
Kokkos::view_alloc(space, values.label()), n_truncated_results);
Kokkos::View<int *, MemorySpace> new_ranks(
Kokkos::view_alloc(space, ranks.label()), n_truncated_results);

using PairValueRank = Kokkos::pair<Value, int>;
using PairValueRankDistance = Kokkos::pair<PairValueRank, float>;
using PairValueDistance = Kokkos::pair<Value, float>;
struct CompareDistance
{
KOKKOS_INLINE_FUNCTION bool operator()(PairValueRankDistance const &lhs,
PairValueRankDistance const &rhs)
KOKKOS_INLINE_FUNCTION bool operator()(PairValueDistance const &lhs,
PairValueDistance const &rhs)
{
// reverse order (larger distance means lower priority)
return lhs.second > rhs.second;
}
};

int const n_results = KokkosExt::lastElement(space, offset);
Kokkos::View<PairValueRankDistance *, MemorySpace> buffer(
Kokkos::View<PairValueDistance *, MemorySpace> buffer(
Kokkos::view_alloc(
space, Kokkos::WithoutInitializing,
"ArborX::DistributedTree::query::filterResults::buffer"),
n_results);
using PriorityQueue =
Details::PriorityQueue<PairValueRankDistance, CompareDistance,
UnmanagedStaticVector<PairValueRankDistance>>;
Details::PriorityQueue<PairValueDistance, CompareDistance,
UnmanagedStaticVector<PairValueDistance>>;

Kokkos::parallel_for(
"ArborX::DistributedTree::query::truncate_results",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_queries),
KOKKOS_LAMBDA(int q) {
if (offset(q + 1) > offset(q))
{
auto local_buffer = Kokkos::subview(
buffer, Kokkos::make_pair(offset(q), offset(q + 1)));
if (offset(q) == offset(q + 1))
return;

PriorityQueue queue(UnmanagedStaticVector<PairValueRankDistance>(
local_buffer.data(), local_buffer.size()));
auto local_buffer = Kokkos::subview(
buffer, Kokkos::make_pair(offset(q), offset(q + 1)));

for (int i = offset(q); i < offset(q + 1); ++i)
{
queue.emplace(PairValueRank{values(i), ranks(i)}, distances(i));
}
PriorityQueue queue(UnmanagedStaticVector<PairValueDistance>(
local_buffer.data(), local_buffer.size()));

int count = 0;
while (!queue.empty() && count < getK(queries(q)))
{
new_values(new_offset(q) + count) = queue.top().first.first;
new_ranks(new_offset(q) + count) = queue.top().first.second;
queue.pop();
++count;
}
for (int i = offset(q); i < offset(q + 1); ++i)
queue.emplace(values(i), distances(i));

int count = 0;
while (!queue.empty() && count < getK(queries(q)))
{
new_values(new_offset(q) + count) = queue.top().first;
queue.pop();
++count;
}
});
values = new_values;
ranks = new_ranks;
offset = new_offset;
}

Expand Down

0 comments on commit 3279a65

Please sign in to comment.