Skip to content

Commit

Permalink
Better way for a user to indicate the intention
Browse files Browse the repository at this point in the history
The problem with tags is that their behavior is not visible for users.
If a previous version of ArborX ignores tags, it may just silently do
something else.

Instead, we want to fail compilation if the user indicates an intention
we do not support. This is one way to do it: user has to wrap their
callback into a class.
  • Loading branch information
aprokop committed May 9, 2024
1 parent ab1eb5d commit 94330e2
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 38 deletions.
21 changes: 0 additions & 21 deletions src/details/ArborX_Callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
#ifndef ARBORX_CALLBACKS_HPP
#define ARBORX_CALLBACKS_HPP

#include <ArborX_Config.hpp>

#include <ArborX_AccessTraits.hpp>
#include <ArborX_Predicates.hpp> // is_valid_predicate_tag

Expand Down Expand Up @@ -46,25 +44,6 @@ struct DefaultCallback
}
};

#ifdef ARBORX_ENABLE_MPI
struct ConstrainedNearestCallbackTag
{};

struct DefaultCallbackWithRank
{
using tag = ConstrainedNearestCallbackTag;

int _rank;

template <typename Predicate, typename Value, typename OutputFunctor>
KOKKOS_FUNCTION void operator()(Predicate const &, Value const &value,
OutputFunctor const &out) const
{
out({value, _rank});
}
};
#endif

// archetypal alias for a 'tag' type member in user callbacks
template <typename Callback>
using CallbackTagArchetypeAlias = typename Callback::tag;
Expand Down
16 changes: 11 additions & 5 deletions src/details/ArborX_DetailsDistributedTreeNearest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ void DistributedTreeImpl::queryDispatch2RoundImpl(

Kokkos::Profiling::ScopedRegion guard(prefix);

static_assert(is_constrained_callback_v<Callback>);

if (tree.empty())
{
KokkosExt::reallocWithoutInitializing(space, values, 0);
Expand Down Expand Up @@ -233,11 +235,15 @@ DistributedTreeImpl::queryDispatch(NearestPredicateTag, Tree const &tree,
Callback const &callback, Values &values,
Offset &offset)
{
static_assert(Kokkos::is_detected_v<CallbackTagArchetypeAlias, Callback>);
static_assert(
std::is_same_v<typename Callback::tag, ConstrainedNearestCallbackTag>);
queryDispatch2RoundImpl(NearestPredicateTag{}, tree, space, predicates,
callback, values, offset);
if constexpr (is_constrained_callback_v<Callback>)
{
queryDispatch2RoundImpl(NearestPredicateTag{}, tree, space, predicates,
callback, values, offset);
}
else
{
Kokkos::abort("3-arg callback not implemented yet.");
}
}

} // namespace ArborX::Details
Expand Down
63 changes: 61 additions & 2 deletions src/details/ArborX_DetailsDistributedTreeNearestHelpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,68 @@

namespace ArborX
{

namespace Experimental
{

// Constrained callback is a callback that a user promises to:
// - be not pure
// - be allowed to be called on non-final results
// - produce exactly one result for each match
template <class Callback>
struct ConstrainedDistributedNearestCallback
{
Callback _callback;

template <class... Args>
KOKKOS_FUNCTION void operator()(Args &&...args) const
{
_callback((Args &&) args...);
}
};

template <class Callback>
auto declare_callback_constrained(Callback const &callback)
{
return ConstrainedDistributedNearestCallback<Callback>{callback};
}

} // namespace Experimental

namespace Details
{

struct DefaultCallbackWithRank
{
int _rank;

template <typename Predicate, typename Value, typename OutputFunctor>
KOKKOS_FUNCTION void operator()(Predicate const &, Value const &value,
OutputFunctor const &out) const
{
out({value, _rank});
}
};

template <class Callback>
struct is_constrained_callback : std::false_type
{};
template <class Callback>
struct is_constrained_callback<
Experimental::ConstrainedDistributedNearestCallback<Callback>>
: std::true_type
{};
template <>
struct is_constrained_callback<DefaultCallback> : std::true_type
{};
template <>
struct is_constrained_callback<DefaultCallbackWithRank> : std::true_type
{};

template <class Callback>
inline constexpr bool is_constrained_callback_v =
is_constrained_callback<Callback>::value;

template <class Predicates, class Distances>
struct WithinDistanceFromPredicates
{
Expand Down Expand Up @@ -116,7 +175,7 @@ struct CallbackWithDistance
if constexpr (UseValues)
{
OutValue out_value;
int count = 0;
[[maybe_unused]] int count = 0;
_callback(query, value, [&](OutValue const &ov) {
out_value = ov;
++count;
Expand Down Expand Up @@ -193,7 +252,7 @@ struct CallbackWithDistance<
if constexpr (UseValues)
{
OutValue out_value;
int count = 0;
[[maybe_unused]] int count = 0;
_callback(query, index, [&](OutValue const &ov) {
out_value = ov;
++count;
Expand Down
21 changes: 11 additions & 10 deletions test/tstDistributedTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ struct PairRankIndex

struct DistributedNearestCallback
{
using tag = ArborX::Details::ConstrainedNearestCallbackTag;

int rank;

template <typename Predicate, typename Value, typename OutputFunctor>
Expand Down Expand Up @@ -167,19 +165,22 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(hello_world, DeviceType, ARBORX_DEVICE_TYPES)
// Now do the same with callbacks
if (comm_rank < comm_size - 1)
{
ARBORX_TEST_QUERY_TREE_CALLBACK(ExecutionSpace{}, tree, nearest_queries,
DistributedNearestCallback{comm_rank},
make_reference_solution<PairRankIndex>(
{{comm_size - 1 - comm_rank, 0},
{comm_size - 2 - comm_rank, n - 1},
{comm_size - 1 - comm_rank, 1}},
{0, 3}));
ARBORX_TEST_QUERY_TREE_CALLBACK(
ExecutionSpace{}, tree, nearest_queries,
ArborX::Experimental::declare_callback_constrained(
DistributedNearestCallback{comm_rank}),
make_reference_solution<PairRankIndex>(
{{comm_size - 1 - comm_rank, 0},
{comm_size - 2 - comm_rank, n - 1},
{comm_size - 1 - comm_rank, 1}},
{0, 3}));
}
else
{
ARBORX_TEST_QUERY_TREE_CALLBACK(
ExecutionSpace{}, tree, nearest_queries,
DistributedNearestCallback{comm_rank},
ArborX::Experimental::declare_callback_constrained(
DistributedNearestCallback{comm_rank}),
make_reference_solution<PairRankIndex>(
{{comm_size - 1 - comm_rank, 0}, {comm_size - 1 - comm_rank, 1}},
{0, 2}));
Expand Down

0 comments on commit 94330e2

Please sign in to comment.