Skip to content

Commit

Permalink
Merge pull request #1001 from dalg24/modernize_callback_detection
Browse files Browse the repository at this point in the history
Modernize and fix callback detection
  • Loading branch information
aprokop committed Dec 28, 2023
2 parents 1dc256d + c697fad commit fffb2e0
Showing 1 changed file with 45 additions and 84 deletions.
129 changes: 45 additions & 84 deletions src/details/ArborX_Callbacks.hpp
Expand Up @@ -39,32 +39,22 @@ struct PostCallbackTag

struct DefaultCallback
{
template <typename Query, typename Value, typename OutputFunctor>
KOKKOS_FUNCTION void operator()(Query const &, Value const &value,
OutputFunctor const &output) const
template <typename Predicate, typename Value, typename OutputFunctor>
KOKKOS_FUNCTION void operator()(Predicate const &, Value const &value,
OutputFunctor const &out) const
{
output(value);
out(value);
}
};

// archetypal expression for user callbacks
template <typename Callback, typename Predicate, typename Value, typename Out>
using InlineCallbackArchetypeExpression =
std::invoke_result_t<Callback, Predicate, Value, Out>;

// legacy nearest predicate archetypal expression for user callbacks
template <typename Callback, typename Predicate, typename Out>
using Legacy_NearestPredicateInlineCallbackArchetypeExpression =
std::invoke_result_t<Callback, Predicate, int, float, Out>;

// archetypal alias for a 'tag' type member in user callbacks
template <typename Callback>
using CallbackTagArchetypeAlias = typename Callback::tag;

template <typename Callback>
struct is_tagged_post_callback
: std::is_same<Kokkos::detected_t<CallbackTagArchetypeAlias, Callback>,
PostCallbackTag>::type
: Kokkos::is_detected_exact<PostCallbackTag, CallbackTagArchetypeAlias,
Callback>::type
{};

// output functor to pass to the callback during detection
Expand Down Expand Up @@ -99,72 +89,47 @@ void check_valid_callback(Callback const &callback, Predicates const &,
using PredicateTag = typename AccessTraitsHelper<Access>::tag;
using Predicate = typename AccessTraitsHelper<Access>::type;

static_assert(!(std::is_same<PredicateTag, NearestPredicateTag>{} &&
Kokkos::is_detected<
Legacy_NearestPredicateInlineCallbackArchetypeExpression,
Callback, Predicate, OutputFunctorHelper<OutputView>>{}),
static_assert(!(std::is_same_v<PredicateTag, NearestPredicateTag> &&
std::is_invocable_v<Callback const &, Predicate, int, float,
OutputFunctorHelper<OutputView>>),
R"error(Callback signature has changed for nearest predicates.
See https://github.com/arborx/ArborX/pull/366 for more details.
Sorry!)error");

static_assert(is_valid_predicate_tag<PredicateTag>::value &&
Kokkos::is_detected<InlineCallbackArchetypeExpression,
Callback, Predicate, Value,
OutputFunctorHelper<OutputView>>{},
std::is_invocable_v<Callback const &, Predicate, Value,
OutputFunctorHelper<OutputView>>,
"Callback 'operator()' does not have the correct signature");

static_assert(
std::is_void<Kokkos::detected_t<InlineCallbackArchetypeExpression,
Callback, Predicate, Value,
OutputFunctorHelper<OutputView>>>{},
std::is_void_v<std::invoke_result_t<Callback const &, Predicate, Value,
OutputFunctorHelper<OutputView>>>,
"Callback 'operator()' return type must be void");
}

// EXPERIMENTAL archetypal expression for user callbacks
template <typename Callback, typename Predicate, typename Primitive>
using Experimental_CallbackArchetypeExpression =
std::invoke_result_t<Callback, Predicate, Primitive>;

// Determine whether the callback returns a hint to exit the tree traversal
// early.
template <typename Callback, typename Predicate, typename Primitive>
struct invoke_callback_and_check_early_exit_helper
: std::is_same<CallbackTreeTraversalControl,
Kokkos::detected_t<Experimental_CallbackArchetypeExpression,
Callback, Predicate, Primitive>>::type
{};

// Invoke a callback that may return a hint to interrupt the tree traversal and
// return true for early exit, or false for normal continuation.
template <typename Callback, typename Predicate, typename Primitive>
KOKKOS_INLINE_FUNCTION
std::enable_if_t<invoke_callback_and_check_early_exit_helper<
std::decay_t<Callback>, std::decay_t<Predicate>,
std::decay_t<Primitive>>::value,
bool>
invoke_callback_and_check_early_exit(Callback &&callback,
Predicate &&predicate,
Primitive &&primitive)
KOKKOS_FUNCTION bool invoke_callback_and_check_early_exit(Callback &&callback,
Predicate &&predicate,
Primitive &&primitive)
{
return ((Callback &&) callback)((Predicate &&) predicate,
(Primitive &&) primitive) ==
CallbackTreeTraversalControl::early_exit;
}

// Invoke a callback that does not return a hint. Always return false to
// signify that the tree traversal should continue normally.
template <typename Callback, typename Predicate, typename Primitive>
KOKKOS_INLINE_FUNCTION
std::enable_if_t<!invoke_callback_and_check_early_exit_helper<
std::decay_t<Callback>, std::decay_t<Predicate>,
std::decay_t<Primitive>>::value,
bool>
invoke_callback_and_check_early_exit(Callback &&callback,
Predicate &&predicate,
Primitive &&primitive)
{
((Callback &&) callback)((Predicate &&) predicate, (Primitive &&) primitive);
return false;
if constexpr (std::is_same_v<CallbackTreeTraversalControl,
std::invoke_result_t<Callback &&, Predicate &&,
Primitive &&>>)
{
// Invoke a callback that may return a hint to interrupt the tree traversal
// and return true for early exit, or false for normal continuation.
return ((Callback &&) callback)((Predicate &&) predicate,
(Primitive &&) primitive) ==
CallbackTreeTraversalControl::early_exit;
}
else
{
// Invoke a callback that does not return a hint. Always return false to
// signify that the tree traversal should continue normally.
((Callback &&) callback)((Predicate &&) predicate,
(Primitive &&) primitive);
return false;
}
}

template <typename Value, typename Callback, typename Predicates>
Expand All @@ -179,29 +144,25 @@ void check_valid_callback(Callback const &callback, Predicates const &)
static_assert(is_valid_predicate_tag<PredicateTag>::value,
"The predicate tag is not valid");

static_assert(Kokkos::is_detected<Experimental_CallbackArchetypeExpression,
Callback, Predicate, Value>{},
static_assert(std::is_invocable_v<Callback const &, Predicate, Value>,
"Callback 'operator()' does not have the correct signature");

static_assert(
!(std::is_same<PredicateTag, SpatialPredicateTag>{} ||
std::is_same<PredicateTag,
Experimental::OrderedSpatialPredicateTag>{}) ||
(std::is_same<
!(std::is_same_v<PredicateTag, SpatialPredicateTag> ||
std::is_same_v<PredicateTag,
Experimental::OrderedSpatialPredicateTag>) ||
(std::is_same_v<
CallbackTreeTraversalControl,
Kokkos::detected_t<Experimental_CallbackArchetypeExpression,
Callback, Predicate, Value>>{} ||
std::is_void<
Kokkos::detected_t<Experimental_CallbackArchetypeExpression,
Callback, Predicate, Value>>{}),
std::invoke_result_t<Callback const &, Predicate, Value>> ||
std::is_void_v<
std::invoke_result_t<Callback const &, Predicate, Value>>),
"Callback 'operator()' return type must be void or "
"ArborX::CallbackTreeTraversalControl");

static_assert(
!std::is_same<PredicateTag, NearestPredicateTag>{} ||
std::is_void<
Kokkos::detected_t<Experimental_CallbackArchetypeExpression,
Callback, Predicate, Value>>{},
!std::is_same_v<PredicateTag, NearestPredicateTag> ||
std::is_void_v<
std::invoke_result_t<Callback const &, Predicate, Value>>,
"Callback 'operator()' return type must be void");
}

Expand Down

0 comments on commit fffb2e0

Please sign in to comment.