forked from arborx/ArborX
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ArborX_Callbacks.hpp
189 lines (157 loc) · 6.49 KB
/
ArborX_Callbacks.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
/****************************************************************************
* Copyright (c) 2017-2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
* distributed under a BSD 3-clause license. For the licensing terms see *
* the LICENSE file in the top-level directory. *
* *
* SPDX-License-Identifier: BSD-3-Clause *
****************************************************************************/
#ifndef ARBORX_CALLBACKS_HPP
#define ARBORX_CALLBACKS_HPP
#include <ArborX_Config.hpp>
#include <ArborX_AccessTraits.hpp>
#include <ArborX_Predicates.hpp> // is_valid_predicate_tag
#include <Kokkos_DetectionIdiom.hpp>
#include <Kokkos_Macros.hpp>
#include <type_traits>
namespace ArborX
{
// Hint a tree-traversal callback may return to steer the traversal.
// A callback returning `early_exit` makes
// invoke_callback_and_check_early_exit report that the traversal should be
// interrupted; `normal_continuation` lets it proceed. The explicit values
// below are the same ones the compiler would assign implicitly.
enum class CallbackTreeTraversalControl
{
  early_exit = 0,
  normal_continuation = 1
};
namespace Details
{
// Marker type. A callback whose nested `tag` alias is exactly this type is
// detected by is_tagged_post_callback.
struct PostCallbackTag
{
};

// Callback that ignores the predicate and hands every value it receives
// straight to the output functor.
struct DefaultCallback
{
  template <typename Predicate, typename Value, typename OutputFunctor>
  KOKKOS_FUNCTION void operator()(const Predicate & /*predicate*/,
                                  const Value &v,
                                  const OutputFunctor &emit) const
  {
    emit(v);
  }
};
#ifdef ARBORX_ENABLE_MPI
// Marker type exposed by DefaultCallbackWithRank through its nested `tag`.
struct ConstrainedNearestCallbackTag
{
};

// Variant of DefaultCallback that emits each value paired with the stored
// `_rank`, brace-initializing the output functor's argument as
// {value, _rank}.
struct DefaultCallbackWithRank
{
  using tag = ConstrainedNearestCallbackTag;

  // NOTE(review): presumably the MPI rank of the process that owns the
  // values — confirm against the call site that constructs this callback.
  int _rank;

  template <typename Predicate, typename Value, typename OutputFunctor>
  KOKKOS_FUNCTION void operator()(const Predicate & /*predicate*/,
                                  const Value &v,
                                  const OutputFunctor &emit) const
  {
    emit({v, _rank});
  }
};
#endif
// archetypal alias for a 'tag' type member in user callbacks
template <typename Callback>
using CallbackTagArchetypeAlias = typename Callback::tag;
template <typename Callback>
struct is_tagged_post_callback
: Kokkos::is_detected_exact<PostCallbackTag, CallbackTagArchetypeAlias,
Callback>::type
{};
// Do-nothing output functor. It is passed to a callback when probing the
// callback's signature at compile time and discards whatever it receives.
template <typename T>
struct Sink
{
  KOKKOS_FUNCTION void operator()(const T & /*discarded*/) const {}
};

// Stand-in output functor whose argument type is OutputView's value_type.
template <typename OutputView>
using OutputFunctorHelper = Sink<typename OutputView::value_type>;
// Compile-time guard for callback types NVCC cannot handle here.
// When compiling with NVCC, compilation fails with a readable message if the
// callback's type satisfies __nv_is_extended_host_device_lambda_closure_type.
// Per the original note, omitting this check led to a segmentation fault
// with no diagnostic at all. With any other compiler this is a no-op.
template <typename Callback>
void check_generic_lambda_support(const Callback & /*callback*/)
{
#if defined(__NVCC__)
  static_assert(
      !__nv_is_extended_host_device_lambda_closure_type(Callback),
      "__host__ __device__ extended lambdas cannot be generic lambdas");
#endif
}
template <typename Value, typename Callback, typename Predicates,
typename OutputView>
void check_valid_callback(Callback const &callback, Predicates const &,
OutputView const &)
{
check_generic_lambda_support(callback);
using Predicate =
typename AccessValues<Predicates, PredicatesTag>::value_type;
using PredicateTag = typename Predicate::Tag;
static_assert(!(std::is_same_v<PredicateTag, NearestPredicateTag> &&
std::is_invocable_v<Callback const &, Predicate, int, float,
OutputFunctorHelper<OutputView>>),
R"error(Callback signature has changed for nearest predicates.
See https://github.com/arborx/ArborX/pull/366 for more details.
Sorry!)error");
static_assert(is_valid_predicate_tag<PredicateTag>::value &&
std::is_invocable_v<Callback const &, Predicate, Value,
OutputFunctorHelper<OutputView>>,
"Callback 'operator()' does not have the correct signature");
static_assert(
std::is_void_v<std::invoke_result_t<Callback const &, Predicate, Value,
OutputFunctorHelper<OutputView>>>,
"Callback 'operator()' return type must be void");
}
// Invoke `callback(predicate, primitive)`, forwarding every argument with its
// original value category, and report whether the tree traversal should stop.
//
// Returns true only when the callback's return type is exactly
// CallbackTreeTraversalControl and it returned `early_exit`. For any other
// return type (e.g. void) the result is discarded and false is returned,
// meaning the traversal continues normally.
template <typename Callback, typename Predicate, typename Primitive>
KOKKOS_FUNCTION bool invoke_callback_and_check_early_exit(Callback &&callback,
                                                          Predicate &&predicate,
                                                          Primitive &&primitive)
{
  // static_cast<T &&>(x) is exactly what std::forward<T>(x) does. It is
  // spelled out (the original used C-style casts for the same effect),
  // presumably so that device code does not need to call a std:: function.
  using Result =
      std::invoke_result_t<Callback &&, Predicate &&, Primitive &&>;

  if constexpr (std::is_same_v<Result, CallbackTreeTraversalControl>)
  {
    // The callback returns a traversal hint: honor it.
    return static_cast<Callback &&>(callback)(
               static_cast<Predicate &&>(predicate),
               static_cast<Primitive &&>(primitive)) ==
           CallbackTreeTraversalControl::early_exit;
  }
  else
  {
    // No hint available: run the callback and keep traversing.
    static_cast<Callback &&>(callback)(static_cast<Predicate &&>(predicate),
                                       static_cast<Primitive &&>(primitive));
    return false;
  }
}
template <typename Value, typename Callback, typename Predicates>
void check_valid_callback(Callback const &callback, Predicates const &)
{
check_generic_lambda_support(callback);
using Predicate =
typename AccessValues<Predicates, PredicatesTag>::value_type;
using PredicateTag = typename Predicate::Tag;
static_assert(is_valid_predicate_tag<PredicateTag>::value,
"The predicate tag is not valid");
static_assert(std::is_invocable_v<Callback const &, Predicate, Value>,
"Callback 'operator()' does not have the correct signature");
static_assert(
!(std::is_same_v<PredicateTag, SpatialPredicateTag> ||
std::is_same_v<PredicateTag, OrderedSpatialPredicateTag>) ||
(std::is_same_v<
CallbackTreeTraversalControl,
std::invoke_result_t<Callback const &, Predicate, Value>> ||
std::is_void_v<
std::invoke_result_t<Callback const &, Predicate, Value>>),
"Callback 'operator()' return type must be void or "
"ArborX::CallbackTreeTraversalControl");
static_assert(
!std::is_same_v<PredicateTag, NearestPredicateTag> ||
std::is_void_v<
std::invoke_result_t<Callback const &, Predicate, Value>>,
"Callback 'operator()' return type must be void");
}
} // namespace Details
} // namespace ArborX
#endif