Skip to content

Commit

Permalink
Merge pull request #965 from aprokop/callbacks-3-arg
Browse files Browse the repository at this point in the history
APIv2: change signature of the 3-argument callback
  • Loading branch information
aprokop committed Nov 3, 2023
2 parents 80fbb31 + 45ab513 commit 3080e6b
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 92 deletions.
6 changes: 4 additions & 2 deletions benchmarks/brute_force_vs_bvh/brute_force_vs_bvh_timpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ static void run_fp(int nprimitives, int nqueries, int nrepeats)

Kokkos::View<int *, ExecutionSpace> indices("Benchmark::indices_ref", 0);
Kokkos::View<int *, ExecutionSpace> offset("Benchmark::offset_ref", 0);
bvh.query(space, predicates, indices, offset);
bvh.query(space, predicates, ArborX::Details::LegacyDefaultCallback{},
indices, offset);

space.fence();
double time = timer.seconds();
Expand All @@ -117,7 +118,8 @@ static void run_fp(int nprimitives, int nqueries, int nrepeats)

Kokkos::View<int *, ExecutionSpace> indices("Benchmark::indices", 0);
Kokkos::View<int *, ExecutionSpace> offset("Benchmark::offset", 0);
brute.query(space, predicates, indices, offset);
brute.query(space, predicates, ArborX::Details::LegacyDefaultCallback{},
indices, offset);

space.fence();
double time = timer.seconds();
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/dbscan/ArborX_DBSCANVerification.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ bool verifyDBSCAN(ExecutionSpace exec_space, Primitives const &primitives,

Kokkos::View<int *, MemorySpace> indices("ArborX::DBSCAN::indices", 0);
Kokkos::View<int *, MemorySpace> offset("ArborX::DBSCAN::offset", 0);
ArborX::query(bvh, exec_space, predicates, indices, offset);
ArborX::query(bvh, exec_space, predicates,
ArborX::Details::LegacyDefaultCallback{}, indices, offset);

auto passed = Details::verifyClusters(exec_space, indices, offset, labels,
core_min_size);
Expand Down
50 changes: 40 additions & 10 deletions src/ArborX_BruteForce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class BasicBruteForce
KokkosExt::ScopedProfileRegion guard("ArborX::BruteForce::query_crs");

Details::CrsGraphWrapperImpl::
check_valid_callback_if_first_argument_is_not_a_view(callback_or_view,
predicates, view);
check_valid_callback_if_first_argument_is_not_a_view<value_type>(
callback_or_view, predicates, view);

using Access = AccessTraits<Predicates, PredicatesTag>;
using Tag = typename Details::AccessTraitsHelper<Access>::tag;
Expand Down Expand Up @@ -101,8 +101,6 @@ class BruteForce
Details::DefaultIndexableGetter, BoundingVolume>;

public:
using legacy_tree = void;

using bounding_volume_type = typename base_type::bounding_volume_type;

BruteForce() = default;
Expand All @@ -123,20 +121,52 @@ class BruteForce
void query(ExecutionSpace const &space, Predicates const &predicates,
Callback const &callback, Ignore = Ignore()) const
{
Details::check_valid_callback<int>(callback, predicates);
base_type::query(space, predicates,
Details::LegacyCallbackWrapper<Callback>{callback});
}

template <typename ExecutionSpace, typename Predicates,
typename CallbackOrView, typename View, typename... Args>
template <typename ExecutionSpace, typename Predicates, typename View,
typename... Args>
std::enable_if_t<Kokkos::is_view_v<std::decay_t<View>>>
query(ExecutionSpace const &space, Predicates const &predicates,
CallbackOrView &&callback_or_view, View &&view, Args &&...args) const
query(ExecutionSpace const &space, Predicates const &predicates, View &&view,
Args &&...args) const
{
base_type::query(space, predicates,
std::forward<CallbackOrView>(callback_or_view),
base_type::query(space, predicates, Details::LegacyDefaultCallback{},
std::forward<View>(view), std::forward<Args>(args)...);
}

template <typename ExecutionSpace, typename Predicates, typename Callback,
typename OutputView, typename OffsetView, typename... Args>
std::enable_if_t<!Kokkos::is_view_v<std::decay_t<Callback>>>
query(ExecutionSpace const &space, Predicates const &predicates,
Callback &&callback, OutputView &&out, OffsetView &&offset,
Args &&...args) const
{
if constexpr (!Details::is_tagged_post_callback<
std::decay_t<Callback>>::value)
{
Details::check_valid_callback<int>(callback, predicates, out);
base_type::query(space, predicates,
Details::LegacyCallbackWrapper<std::decay_t<Callback>>{
std::forward<Callback>(callback)},
std::forward<OutputView>(out),
std::forward<OffsetView>(offset),
std::forward<Args>(args)...);
}
else
{
KokkosExt::ScopedProfileRegion guard("ArborX::BruteForce::query_crs");

Kokkos::View<int *, MemorySpace> indices(
"ArborX::CrsGraphWrapper::query::indices", 0);
base_type::query(space, predicates, Details::LegacyDefaultCallback{},
indices, std::forward<OffsetView>(offset),
std::forward<Args>(args)...);
callback(predicates, std::forward<OffsetView>(offset), indices,
std::forward<OutputView>(out));
}
}
};

template <typename MemorySpace, typename Value, typename IndexableGetter,
Expand Down
50 changes: 40 additions & 10 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ class BasicBoundingVolumeHierarchy
KokkosExt::ScopedProfileRegion guard("ArborX::BVH::query_crs");

Details::CrsGraphWrapperImpl::
check_valid_callback_if_first_argument_is_not_a_view(callback_or_view,
predicates, view);
check_valid_callback_if_first_argument_is_not_a_view<value_type>(
callback_or_view, predicates, view);

using Access = AccessTraits<Predicates, PredicatesTag>;
using Tag = typename Details::AccessTraitsHelper<Access>::tag;
Expand Down Expand Up @@ -146,8 +146,6 @@ class BoundingVolumeHierarchy
Details::DefaultIndexableGetter, Box>;

public:
using legacy_tree = void;

using bounding_volume_type = typename base_type::bounding_volume_type;

BoundingVolumeHierarchy() = default; // build an empty tree
Expand All @@ -172,22 +170,54 @@ class BoundingVolumeHierarchy
Experimental::TraversalPolicy const &policy =
Experimental::TraversalPolicy()) const
{
Details::check_valid_callback<int>(callback, predicates);
base_type::query(space, predicates,
Details::LegacyCallbackWrapper<Callback>{callback},
policy);
}

template <typename ExecutionSpace, typename Predicates,
typename CallbackOrView, typename View, typename... Args>
template <typename ExecutionSpace, typename Predicates, typename View,
typename... Args>
std::enable_if_t<Kokkos::is_view_v<std::decay_t<View>>>
query(ExecutionSpace const &space, Predicates const &predicates,
CallbackOrView &&callback_or_view, View &&view, Args &&...args) const
query(ExecutionSpace const &space, Predicates const &predicates, View &&view,
Args &&...args) const
{
base_type::query(space, predicates,
std::forward<CallbackOrView>(callback_or_view),
base_type::query(space, predicates, Details::LegacyDefaultCallback{},
std::forward<View>(view), std::forward<Args>(args)...);
}

template <typename ExecutionSpace, typename Predicates, typename Callback,
typename OutputView, typename OffsetView, typename... Args>
std::enable_if_t<!Kokkos::is_view_v<std::decay_t<Callback>>>
query(ExecutionSpace const &space, Predicates const &predicates,
Callback &&callback, OutputView &&out, OffsetView &&offset,
Args &&...args) const
{
if constexpr (!Details::is_tagged_post_callback<
std::decay_t<Callback>>::value)
{
Details::check_valid_callback<int>(callback, predicates, out);
base_type::query(space, predicates,
Details::LegacyCallbackWrapper<std::decay_t<Callback>>{
std::forward<Callback>(callback)},
std::forward<OutputView>(out),
std::forward<OffsetView>(offset),
std::forward<Args>(args)...);
}
else
{
KokkosExt::ScopedProfileRegion guard("ArborX::BVH::query_crs");

Kokkos::View<int *, MemorySpace> indices(
"ArborX::CrsGraphWrapper::query::indices", 0);
base_type::query(space, predicates, Details::LegacyDefaultCallback{},
indices, std::forward<OffsetView>(offset),
std::forward<Args>(args)...);
callback(predicates, std::forward<OffsetView>(offset), indices,
std::forward<OutputView>(out));
}
}

template <typename Predicate, typename Callback>
KOKKOS_FUNCTION void query(Experimental::PerThread tag,
Predicate const &predicate,
Expand Down
29 changes: 15 additions & 14 deletions src/details/ArborX_Callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ struct PostCallbackTag

struct DefaultCallback
{
template <typename Query, typename OutputFunctor>
KOKKOS_FUNCTION void operator()(Query const &, int index,
template <typename Query, typename Value, typename OutputFunctor>
KOKKOS_FUNCTION void operator()(Query const &, Value const &value,
OutputFunctor const &output) const
{
output(index);
output(value);
}
};

// archetypal expression for user callbacks
template <typename Callback, typename Predicate, typename Out>
template <typename Callback, typename Predicate, typename Value, typename Out>
using InlineCallbackArchetypeExpression =
std::invoke_result_t<Callback, Predicate, int, Out>;
std::invoke_result_t<Callback, Predicate, Value, Out>;

// legacy nearest predicate archetypal expression for user callbacks
template <typename Callback, typename Predicate, typename Out>
Expand Down Expand Up @@ -88,7 +88,8 @@ void check_generic_lambda_support(Callback const &)
#endif
}

template <typename Callback, typename Predicates, typename OutputView>
template <typename Value, typename Callback, typename Predicates,
typename OutputView>
void check_valid_callback(Callback const &callback, Predicates const &,
OutputView const &)
{
Expand All @@ -106,16 +107,16 @@ void check_valid_callback(Callback const &callback, Predicates const &,
See https://github.com/arborx/ArborX/pull/366 for more details.
Sorry!)error");

static_assert(
is_valid_predicate_tag<PredicateTag>::value &&
Kokkos::is_detected<InlineCallbackArchetypeExpression, Callback,
Predicate, OutputFunctorHelper<OutputView>>{},
"Callback 'operator()' does not have the correct signature");
static_assert(is_valid_predicate_tag<PredicateTag>::value &&
Kokkos::is_detected<InlineCallbackArchetypeExpression,
Callback, Predicate, Value,
OutputFunctorHelper<OutputView>>{},
"Callback 'operator()' does not have the correct signature");

static_assert(
std::is_void<
Kokkos::detected_t<InlineCallbackArchetypeExpression, Callback,
Predicate, OutputFunctorHelper<OutputView>>>{},
std::is_void<Kokkos::detected_t<InlineCallbackArchetypeExpression,
Callback, Predicate, Value,
OutputFunctorHelper<OutputView>>>{},
"Callback 'operator()' return type must be void");
}

Expand Down
Loading

0 comments on commit 3080e6b

Please sign in to comment.