Skip to content

Commit

Permalink
Merge pull request #900 from aprokop/8-update_callback
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed Jul 21, 2023
2 parents b68d11a + ad0d4c0 commit 95bd0c0
Show file tree
Hide file tree
Showing 11 changed files with 122 additions and 53 deletions.
3 changes: 2 additions & 1 deletion src/ArborX_BruteForce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ void BruteForce<MemorySpace, BoundingVolume>::query(
using Tag = typename Details::AccessTraitsHelper<Access>::tag;
static_assert(std::is_same<Tag, Details::SpatialPredicateTag>{},
"nearest query not implemented yet");
Details::check_valid_callback(callback, predicates);
using Value = int;
Details::check_valid_callback<Value>(callback, predicates);

Kokkos::Profiling::pushRegion("ArborX::BruteForce::query::spatial");

Expand Down
54 changes: 47 additions & 7 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,51 @@ class BasicBoundingVolumeHierarchy
};

template <typename MemorySpace>
using BoundingVolumeHierarchy =
BasicBoundingVolumeHierarchy<MemorySpace, Details::PairIndexVolume<Box>,
Details::DefaultIndexableGetter, Box>;
class BoundingVolumeHierarchy
: public BasicBoundingVolumeHierarchy<MemorySpace,
Details::PairIndexVolume<Box>,
Details::DefaultIndexableGetter, Box>
{
using base_type =
BasicBoundingVolumeHierarchy<MemorySpace, Details::PairIndexVolume<Box>,
Details::DefaultIndexableGetter, Box>;

public:
using legacy_tree = void;

BoundingVolumeHierarchy() = default; // build an empty tree

template <typename ExecutionSpace, typename Primitives,
typename SpaceFillingCurve = Experimental::Morton64>
BoundingVolumeHierarchy(ExecutionSpace const &space,
Primitives const &primitives,
SpaceFillingCurve const &curve = SpaceFillingCurve())
: base_type(space, primitives, curve)
{}

template <typename ExecutionSpace, typename Predicates, typename Callback>
void query(ExecutionSpace const &space, Predicates const &predicates,
Callback const &callback,
Experimental::TraversalPolicy const &policy =
Experimental::TraversalPolicy()) const
{
base_type::query(space, predicates,
Details::LegacyCallbackWrapper<
Callback, typename base_type::value_type>{callback},
policy);
}

template <typename ExecutionSpace, typename Predicates,
typename CallbackOrView, typename View, typename... Args>
std::enable_if_t<Kokkos::is_view<std::decay_t<View>>{}>
query(ExecutionSpace const &space, Predicates const &predicates,
CallbackOrView &&callback_or_view, View &&view, Args &&...args) const
{
ArborX::query(*this, space, predicates,
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
}
};

template <typename MemorySpace>
using BVH = BoundingVolumeHierarchy<MemorySpace>;
Expand Down Expand Up @@ -228,7 +270,7 @@ void BasicBoundingVolumeHierarchy<
MemorySpace, Value, IndexableGetter,
BoundingVolume>::query(ExecutionSpace const &space,
Predicates const &predicates,
Callback const &legacy_callback,
Callback const &callback,
Experimental::TraversalPolicy const &policy) const
{
static_assert(
Expand All @@ -238,9 +280,7 @@ void BasicBoundingVolumeHierarchy<
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");
Details::check_valid_callback(legacy_callback, predicates);
Details::LegacyCallbackWrapper<Callback, value_type> callback{
legacy_callback};
Details::check_valid_callback<value_type>(callback, predicates);

using Tag = typename Details::AccessTraitsHelper<Access>::tag;
std::string profiling_prefix = "ArborX::BVH::query::";
Expand Down
12 changes: 6 additions & 6 deletions src/details/ArborX_Callbacks.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/****************************************************************************
* Copyright (c) 2017-2022 by the ArborX authors *
* Copyright (c) 2017-2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
Expand Down Expand Up @@ -170,7 +170,7 @@ KOKKOS_INLINE_FUNCTION
return false;
}

template <typename Callback, typename Predicates>
template <typename Value, typename Callback, typename Predicates>
void check_valid_callback(Callback const &callback, Predicates const &)
{
check_generic_lambda_support(callback);
Expand All @@ -183,7 +183,7 @@ void check_valid_callback(Callback const &callback, Predicates const &)
"The predicate tag is not valid");

static_assert(Kokkos::is_detected<Experimental_CallbackArchetypeExpression,
Callback, Predicate, int>{},
Callback, Predicate, Value>{},
"Callback 'operator()' does not have the correct signature");

static_assert(
Expand All @@ -193,18 +193,18 @@ void check_valid_callback(Callback const &callback, Predicates const &)
(std::is_same<
CallbackTreeTraversalControl,
Kokkos::detected_t<Experimental_CallbackArchetypeExpression,
Callback, Predicate, int>>{} ||
Callback, Predicate, Value>>{} ||
std::is_void<
Kokkos::detected_t<Experimental_CallbackArchetypeExpression,
Callback, Predicate, int>>{}),
Callback, Predicate, Value>>{}),
"Callback 'operator()' return type must be void or "
"ArborX::CallbackTreeTraversalControl");

static_assert(
!std::is_same<PredicateTag, NearestPredicateTag>{} ||
std::is_void<
Kokkos::detected_t<Experimental_CallbackArchetypeExpression,
Callback, Predicate, int>>{},
Callback, Predicate, Value>>{},
"Callback 'operator()' return type must be void");
}

Expand Down
32 changes: 24 additions & 8 deletions src/details/ArborX_DetailsCrsGraphWrapperImpl.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/****************************************************************************
* Copyright (c) 2017-2022 by the ArborX authors *
* Copyright (c) 2017-2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
Expand Down Expand Up @@ -51,7 +51,8 @@ struct SecondPassTag
{};

template <typename PassTag, typename Predicates, typename Callback,
typename OutputView, typename CountView, typename PermutedOffset>
typename OutputView, typename CountView, typename PermutedOffset,
bool Legacy>
struct InsertGenerator
{
Callback _callback;
Expand All @@ -63,6 +64,15 @@ struct InsertGenerator
using Access = AccessTraits<Predicates, PredicatesTag>;
using PredicateType = typename AccessTraitsHelper<Access>::type;

// Legacy callback wrapper
template <typename Value, bool B = Legacy,
typename Enable = std::enable_if_t<!B>>
KOKKOS_FUNCTION auto operator()(PredicateType const &predicate,
Value const &value) const
{
return (*this)(predicate, (int)value.index);
}

KOKKOS_FUNCTION auto operator()(PredicateType const &predicate,
int primitive_index) const
{
Expand Down Expand Up @@ -114,6 +124,9 @@ struct InsertGenerator
namespace CrsGraphWrapperImpl
{

template <typename Callback>
using LegacyTreeArchetypeExpression = typename Callback::legacy_tree;

template <typename ExecutionSpace, typename Tree, typename Predicates,
typename Callback, typename OutputView, typename OffsetView,
typename PermuteType>
Expand Down Expand Up @@ -143,6 +156,9 @@ void queryImpl(ExecutionSpace const &space, Tree const &tree,
using PermutedOffset = PermutedData<OffsetView, PermuteType>;
PermutedOffset permuted_offset = {offset, permute};

constexpr bool Legacy =
Kokkos::is_detected_v<LegacyTreeArchetypeExpression, Tree>;

Kokkos::Profiling::pushRegion(
"ArborX::CrsGraphWrapper::two_pass::first_pass");
bool underflow = false;
Expand All @@ -152,8 +168,8 @@ void queryImpl(ExecutionSpace const &space, Tree const &tree,
tree.query(
space, permuted_predicates,
InsertGenerator<FirstPassTag, PermutedPredicates, Callback, OutputView,
CountView, PermutedOffset>{callback, out, counts,
permuted_offset},
CountView, PermutedOffset, Legacy>{
callback, out, counts, permuted_offset},
ArborX::Experimental::TraversalPolicy().setPredicateSorting(false));

// Detecting overflow is a local operation that needs to be done for every
Expand Down Expand Up @@ -187,8 +203,8 @@ void queryImpl(ExecutionSpace const &space, Tree const &tree,
tree.query(
space, permuted_predicates,
InsertGenerator<FirstPassNoBufferOptimizationTag, PermutedPredicates,
Callback, OutputView, CountView, PermutedOffset>{
callback, out, counts, permuted_offset},
Callback, OutputView, CountView, PermutedOffset,
Legacy>{callback, out, counts, permuted_offset},
ArborX::Experimental::TraversalPolicy().setPredicateSorting(false));
// This may not be true, but it does not matter. As long as we have
// (n_results == 0) check before second pass, this value is not used.
Expand Down Expand Up @@ -249,8 +265,8 @@ void queryImpl(ExecutionSpace const &space, Tree const &tree,
tree.query(
space, permuted_predicates,
InsertGenerator<SecondPassTag, PermutedPredicates, Callback, OutputView,
CountView, PermutedOffset>{callback, out, counts,
permuted_offset},
CountView, PermutedOffset, Legacy>{
callback, out, counts, permuted_offset},
ArborX::Experimental::TraversalPolicy().setPredicateSorting(false));

Kokkos::Profiling::popRegion();
Expand Down
4 changes: 2 additions & 2 deletions src/details/ArborX_DetailsFDBSCAN.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ struct CountUpToN
Kokkos::View<int *, MemorySpace> _counts;
int _n;

template <typename Query>
KOKKOS_FUNCTION auto operator()(Query const &query, int) const
template <typename Query, typename Value>
KOKKOS_FUNCTION auto operator()(Query const &query, Value const &) const
{
auto i = getData(query);
Kokkos::atomic_increment(&_counts(i));
Expand Down
10 changes: 6 additions & 4 deletions src/details/ArborX_DetailsFDBSCANDenseBox.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ struct CountUpToN_DenseBox
, _n(n)
{}

template <typename Query>
KOKKOS_FUNCTION auto operator()(Query const &query, int k) const
template <typename Query, typename Value>
KOKKOS_FUNCTION auto operator()(Query const &query, Value const &value) const
{
using Access = AccessTraits<Primitives, PrimitivesTag>;

int const k = value.index;
auto const i = getData(query);

bool const is_dense_cell = (k < _num_dense_cells);
Expand Down Expand Up @@ -126,11 +127,12 @@ struct FDBSCANDenseBoxCallback
, eps(eps_in)
{}

template <typename Query>
KOKKOS_FUNCTION auto operator()(Query const &query, int k) const
template <typename Query, typename Value>
KOKKOS_FUNCTION auto operator()(Query const &query, Value const &value) const
{
using Access = AccessTraits<Primitives, PrimitivesTag>;

int const k = value.index;
auto const i = ArborX::getData(query);

bool const is_border_point = !_is_core_point(i);
Expand Down
10 changes: 7 additions & 3 deletions src/details/ArborX_DetailsMutualReachabilityDistance.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/****************************************************************************
* Copyright (c) 2017-2022 by the ArborX authors *
* Copyright (c) 2017-2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
Expand Down Expand Up @@ -28,12 +28,16 @@ struct MaxDistance
{
Primitives _primitives;
Distances _distances;

using Access = AccessTraits<Primitives, PrimitivesTag>;
using memory_space = typename Access::memory_space;
using size_type = typename memory_space::size_type;
template <class Predicate>
KOKKOS_FUNCTION void operator()(Predicate const &predicate, size_type i) const

template <class Predicate, typename Value>
KOKKOS_FUNCTION void operator()(Predicate const &predicate,
Value const &value) const
{
size_type const i = value.index;
size_type const j = getData(predicate);
using KokkosExt::max;
auto const distance_ij =
Expand Down
29 changes: 16 additions & 13 deletions test/tstCompileOnlyCallbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ void test_callbacks_compile_only()
check_valid_callback(CallbackMissingTag{}, SpatialPredicates{}, v);
check_valid_callback(CallbackMissingTag{}, NearestPredicates{}, v);

check_valid_callback(CustomCallback{}, SpatialPredicates{});
check_valid_callback(CustomCallback{}, NearestPredicates{});
check_valid_callback<int>(CustomCallback{}, SpatialPredicates{});
check_valid_callback<int>(CustomCallback{}, NearestPredicates{});

// generic lambdas are supported if not using NVCC
#ifndef __NVCC__
Expand All @@ -116,28 +116,31 @@ void test_callbacks_compile_only()
auto const & /*out*/) {},
NearestPredicates{}, v);

check_valid_callback([](auto const & /*predicate*/, int /*primitive*/) {},
SpatialPredicates{});
check_valid_callback<int>(
[](auto const & /*predicate*/, int /*primitive*/) {},
SpatialPredicates{});

check_valid_callback([](auto const & /*predicate*/, int /*primitive*/) {},
NearestPredicates{});
check_valid_callback<int>(
[](auto const & /*predicate*/, int /*primitive*/) {},
NearestPredicates{});
#endif

// Uncomment to see error messages

// check_valid_callback(LegacyNearestPredicateCallback{}, NearestPredicates{},
// v);
// v);

// check_valid_callback(CallbackDoesNotTakeCorrectArgument{},
// SpatialPredicates{}, v);
// SpatialPredicates{}, v);

// check_valid_callback(CustomCallbackNonVoidReturnType{},
// SpatialPredicates{});
// check_valid_callback<int>(CustomCallbackNonVoidReturnType{},
// SpatialPredicates{});

// check_valid_callback(CustomCallbackMissingConstQualifier{},
// SpatialPredicates{});
// check_valid_callback<int>(CustomCallbackMissingConstQualifier{},
// SpatialPredicates{});

#ifndef __NVCC__
// check_valid_callback([](Wrong, int /*primitive*/) {}, SpatialPredicates{});
// check_valid_callback<int>([](Wrong, int /*primitive*/) {},
// SpatialPredicates{});
#endif
}
8 changes: 4 additions & 4 deletions test/tstCompileOnlyTypeRequirements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ KOKKOS_FUNCTION float distance(FakePredicateGeometry, FakeBoundingVolume) { retu

struct PoorManLambda
{
template <class Predicate>
KOKKOS_FUNCTION void operator()(Predicate, int) const
template <class Predicate, typename Value>
KOKKOS_FUNCTION void operator()(Predicate, Value) const
{}
};
} // namespace Test
Expand Down Expand Up @@ -71,7 +71,7 @@ void check_bounding_volume_and_predicate_geometry_type_requirements()
tree.query(ExecutionSpace{}, spatial_predicates, Test::PoorManLambda{});
#ifndef __NVCC__
tree.query(ExecutionSpace{}, spatial_predicates,
KOKKOS_LAMBDA(SpatialPredicate, int){});
KOKKOS_LAMBDA(SpatialPredicate, auto){});
#endif

using NearestPredicate =
Expand All @@ -81,6 +81,6 @@ void check_bounding_volume_and_predicate_geometry_type_requirements()
tree.query(ExecutionSpace{}, nearest_predicates, Test::PoorManLambda{});
#ifndef __NVCC__
tree.query(ExecutionSpace{}, nearest_predicates,
KOKKOS_LAMBDA(NearestPredicate, int){});
KOKKOS_LAMBDA(NearestPredicate, auto){});
#endif
}
6 changes: 4 additions & 2 deletions test/tstDetailsMutualReachabilityDistance.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/****************************************************************************
* Copyright (c) 2017-2022 by the ArborX authors *
* Copyright (c) 2017-2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
Expand Down Expand Up @@ -32,7 +32,9 @@ auto compute_core_distances(ExecutionSpace exec_space,

ARBORX_ASSERT(points.extent_int(0) >= k);
using MemorySpace = typename ExecutionSpace::memory_space;
ArborX::BVH<MemorySpace> bvh{exec_space, points};
ArborX::BasicBoundingVolumeHierarchy<
MemorySpace, ArborX::Details::PairIndexVolume<ArborX::Box>>
bvh{exec_space, points};
Kokkos::View<float *, MemorySpace> distances(
Kokkos::view_alloc(Kokkos::WithoutInitializing, "Test::core_distances"),
bvh.size());
Expand Down
Loading

0 comments on commit 95bd0c0

Please sign in to comment.