Skip to content

Commit

Permalink
Move out the common nearest buffer allocation part
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed Apr 3, 2024
1 parent 668c9b7 commit c0deec3
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 93 deletions.
48 changes: 9 additions & 39 deletions src/details/ArborX_DetailsBruteForceImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <ArborX_DetailsKokkosExtMinMaxOperations.hpp>
#include <ArborX_DetailsKokkosExtStdAlgorithms.hpp>
#include <ArborX_DetailsKokkosExtViewHelpers.hpp>
#include <ArborX_DetailsNearestBufferProvider.hpp>
#include <ArborX_DetailsPriorityQueue.hpp>
#include <ArborX_Exception.hpp>

Expand Down Expand Up @@ -155,37 +156,7 @@ struct BruteForceImpl
int const n_indexables = values.size();
int const n_predicates = predicates.size();

using Buffer = Kokkos::View<Kokkos::pair<int, float> *, MemorySpace>;
using Offset = Kokkos::View<int *, MemorySpace>;
struct BufferProvider
{
Buffer _buffer;
Offset _offset;

KOKKOS_FUNCTION auto operator()(int i) const
{
auto const *offset_ptr = &_offset(i);
return Kokkos::subview(
_buffer, Kokkos::make_pair(*offset_ptr, *(offset_ptr + 1)));
}
};

Offset offset(
Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
"ArborX::BruteForce::query::nearest::offset"),
n_predicates + 1);
Kokkos::parallel_for(
"ArborX::BruteForce::query::nearest::"
"scan_queries_for_numbers_of_neighbors",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_predicates),
KOKKOS_LAMBDA(int i) { offset(i) = getK(predicates(i)); });
KokkosExt::exclusive_scan(space, offset, offset, 0);
int const buffer_size = KokkosExt::lastElement(space, offset);

Buffer buffer(Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
"ArborX::TreeTraversal::nearest::buffer"),
buffer_size);
BufferProvider buffer_provider{buffer, offset};
NearestBufferProvider<MemorySpace> buffer_provider(space, predicates);

Kokkos::parallel_for(
"ArborX::BruteForce::query::nearest::"
Expand All @@ -199,14 +170,8 @@ struct BruteForceImpl
if (k < 1)
return;

auto radius = KokkosExt::ArithmeticTraits::infinity<float>::value;

using PairIndexDistance = Kokkos::pair<int, float>;
static_assert(std::is_same<typename decltype(buffer)::value_type,
PairIndexDistance>::value,
"Type of the elements stored in the buffer passed as "
"argument to "
"TreeTraversal::nearestQuery is not right");
using PairIndexDistance =
typename NearestBufferProvider<MemorySpace>::PairIndexDistance;
struct CompareDistance
{
KOKKOS_INLINE_FUNCTION bool
Expand All @@ -222,6 +187,11 @@ struct BruteForceImpl
heap(UnmanagedStaticVector<PairIndexDistance>(buffer.data(),
buffer.size()));

// Nodes with a distance that exceeds the radius can safely be
// discarded. Initialize the radius to infinity and tighten it once k
// neighbors have been found.
auto radius = KokkosExt::ArithmeticTraits::infinity<float>::value;

for (int j = 0; j < n_indexables; ++j)
{
auto const distance = predicate.distance(indexables(j));
Expand Down
73 changes: 73 additions & 0 deletions src/details/ArborX_DetailsNearestBufferProvider.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/****************************************************************************
* Copyright (c) 2017-2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
* distributed under a BSD 3-clause license. For the licensing terms see *
* the LICENSE file in the top-level directory. *
* *
* SPDX-License-Identifier: BSD-3-Clause *
****************************************************************************/
#ifndef ARBORX_DETAILS_NEAREST_BUFFER_PROVIDER_HPP
#define ARBORX_DETAILS_NEAREST_BUFFER_PROVIDER_HPP

#include <ArborX_DetailsKokkosExtViewHelpers.hpp>

#include <Kokkos_Core.hpp>

namespace ArborX::Details
{

template <typename MemorySpace>
struct NearestBufferProvider
{
static_assert(Kokkos::is_memory_space_v<MemorySpace>);

using PairIndexDistance = Kokkos::pair<int, float>;

Kokkos::View<PairIndexDistance *, MemorySpace> _buffer;
Kokkos::View<int *, MemorySpace> _offset;

NearestBufferProvider() = default;

template <typename ExecutionSpace, typename Predicates>
NearestBufferProvider(ExecutionSpace const &space,
Predicates const &predicates)
: _buffer("ArborX::NearestBufferProvider::buffer", 0)
, _offset("ArborX::NearestBufferProvider::offset", 0)
{
allocateBuffer(space, predicates);
}

KOKKOS_FUNCTION auto operator()(int i) const
{
auto const *offset_ptr = &_offset(i);
return Kokkos::subview(_buffer,
Kokkos::make_pair(*offset_ptr, *(offset_ptr + 1)));
}

template <typename ExecutionSpace, typename Predicates>
void allocateBuffer(ExecutionSpace const &space, Predicates const &predicates)
{
auto const n_queries = predicates.size();

KokkosExt::reallocWithoutInitializing(space, _offset, n_queries + 1);

Kokkos::parallel_for(
"ArborX::NearestBufferProvider::scan_queries_for_numbers_of_neighbors",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_queries),
KOKKOS_CLASS_LAMBDA(int i) { _offset(i) = getK(predicates(i)); });
KokkosExt::exclusive_scan(space, _offset, _offset, 0);
int const buffer_size = KokkosExt::lastElement(space, _offset);
// Allocate buffer over which to perform heap operations in the nearest
// query to store nearest nodes found so far.
// It is not possible to anticipate how much memory to allocate since the
// number of nearest neighbors k is only known at runtime.

KokkosExt::reallocWithoutInitializing(space, _buffer, buffer_size);
}
};

} // namespace ArborX::Details

#endif
64 changes: 10 additions & 54 deletions src/details/ArborX_DetailsTreeTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <ArborX_DetailsKokkosExtArithmeticTraits.hpp>
#include <ArborX_DetailsKokkosExtStdAlgorithms.hpp>
#include <ArborX_DetailsKokkosExtViewHelpers.hpp>
#include <ArborX_DetailsNearestBufferProvider.hpp>
#include <ArborX_DetailsNode.hpp> // ROPE_SENTINEL
#include <ArborX_DetailsPriorityQueue.hpp>
#include <ArborX_DetailsStack.hpp>
Expand Down Expand Up @@ -128,48 +129,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
Predicates _predicates;
Callback _callback;

using Buffer = Kokkos::View<Kokkos::pair<int, float> *, MemorySpace>;
using Offset = Kokkos::View<int *, MemorySpace>;
struct BufferProvider
{
Buffer _buffer;
Offset _offset;

KOKKOS_FUNCTION auto operator()(int i) const
{
auto const *offset_ptr = &_offset(i);
return Kokkos::subview(_buffer,
Kokkos::make_pair(*offset_ptr, *(offset_ptr + 1)));
}
};

BufferProvider _buffer;

template <typename ExecutionSpace>
void allocateBuffer(ExecutionSpace const &space)
{
auto const n_queries = _predicates.size();

Offset offset(Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
"ArborX::TreeTraversal::nearest::offset"),
n_queries + 1);
Kokkos::parallel_for(
"ArborX::TreeTraversal::nearest::"
"scan_queries_for_numbers_of_neighbors",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_queries),
KOKKOS_CLASS_LAMBDA(int i) { offset(i) = getK(_predicates(i)); });
KokkosExt::exclusive_scan(space, offset, offset, 0);
int const buffer_size = KokkosExt::lastElement(space, offset);
// Allocate buffer over which to perform heap operations in
// TreeTraversal::nearestQuery() to store nearest leaf nodes found so far.
// It is not possible to anticipate how much memory to allocate since the
// number of nearest neighbors k is only known at runtime.

Buffer buffer(Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
"ArborX::TreeTraversal::nearest::buffer"),
buffer_size);
_buffer = BufferProvider{buffer, offset};
}
NearestBufferProvider<MemorySpace> _buffer;

template <typename ExecutionSpace>
TreeTraversal(ExecutionSpace const &space, BVH const &bvh,
Expand All @@ -192,7 +152,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
}
else
{
allocateBuffer(space);
_buffer = NearestBufferProvider<MemorySpace>(space, predicates);

Kokkos::parallel_for(
"ArborX::TreeTraversal::nearest",
Expand Down Expand Up @@ -226,17 +186,8 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
if (k < 1)
return;

// Nodes with a distance that exceed that radius can safely be
// discarded. Initialize the radius to infinity and tighten it once k
// neighbors have been found.
auto radius = KokkosExt::ArithmeticTraits::infinity<float>::value;

using PairIndexDistance = Kokkos::pair<int, float>;
static_assert(
std::is_same<typename decltype(buffer)::value_type,
PairIndexDistance>::value,
"Type of the elements stored in the buffer passed as argument to "
"TreeTraversal::nearestQuery is not right");
using PairIndexDistance =
typename NearestBufferProvider<MemorySpace>::PairIndexDistance;
struct CompareDistance
{
KOKKOS_INLINE_FUNCTION bool operator()(PairIndexDistance const &lhs,
Expand Down Expand Up @@ -281,6 +232,11 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
float distance_right = 0.f;
float distance_node = 0.f;

// Nodes with a distance that exceeds the radius can safely be
// discarded. Initialize the radius to infinity and tighten it once k
// neighbors have been found.
auto radius = KokkosExt::ArithmeticTraits::infinity<float>::value;

do
{
bool traverse_left = false;
Expand Down
2 changes: 2 additions & 0 deletions test/tstKokkosToolsAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(bvh_bvh_allocations_prefixed, DeviceType,
void const * /*ptr*/, uint64_t /*size*/) {
std::regex re("^(Testing::"
"|ArborX::BVH::"
"|ArborX::NearestBufferProvider::"
"|ArborX::Sorting::"
"|Kokkos::SortImpl::BinSortFunctor::"
"|Kokkos::Serial::" // unsure what's going on
Expand Down Expand Up @@ -89,6 +90,7 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(bvh_query_allocations_prefixed, DeviceType,
void const * /*ptr*/, uint64_t /*size*/) {
std::regex re("^(Testing::"
"|ArborX::BVH::query::"
"|ArborX::NearestBufferProvider::"
"|ArborX::TreeTraversal::spatial::"
"|ArborX::TreeTraversal::nearest::"
"|ArborX::CrsGraphWrapper::"
Expand Down
2 changes: 2 additions & 0 deletions test/tstKokkosToolsDistributedAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(
std::regex re("^(Testing::"
"|ArborX::DistributedTree::"
"|ArborX::BVH::"
"|ArborX::NearestBufferProvider::"
"|ArborX::Sorting::"
").*");
BOOST_TEST(std::regex_match(label, re),
Expand Down Expand Up @@ -73,6 +74,7 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(
"|ArborX::DistributedTree::query::"
"|ArborX::Distributor::"
"|ArborX::BVH::query::"
"|ArborX::NearestBufferProvider::"
"|ArborX::TreeTraversal::spatial::"
"|ArborX::TreeTraversal::nearest::"
"|ArborX::CrsGraphWrapper::"
Expand Down

0 comments on commit c0deec3

Please sign in to comment.