Skip to content

Commit

Permalink
Minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed May 3, 2024
1 parent 1b4832b commit ab1eb5d
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/ArborX_DistributedTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class DistributedTree<MemorySpace, Details::LegacyDefaultTemplateValue,
MPI_Comm_rank(base_type::getComm(), &comm_rank);

base_type::query(space, predicates,
Details::LegacyDefaultCallbackWithRank{comm_rank},
Details::DefaultCallbackWithRank{comm_rank},
std::forward<IndicesAndRanks>(indices_and_ranks),
std::forward<OffsetView>(offset));
}
Expand Down
21 changes: 21 additions & 0 deletions src/details/ArborX_Callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#ifndef ARBORX_CALLBACKS_HPP
#define ARBORX_CALLBACKS_HPP

#include <ArborX_Config.hpp>

#include <ArborX_AccessTraits.hpp>
#include <ArborX_Predicates.hpp> // is_valid_predicate_tag

Expand Down Expand Up @@ -44,6 +46,25 @@ struct DefaultCallback
}
};

#ifdef ARBORX_ENABLE_MPI
struct ConstrainedNearestCallbackTag
{};

struct DefaultCallbackWithRank
{
using tag = ConstrainedNearestCallbackTag;

int _rank;

template <typename Predicate, typename Value, typename OutputFunctor>
KOKKOS_FUNCTION void operator()(Predicate const &, Value const &value,
OutputFunctor const &out) const
{
out({value, _rank});
}
};
#endif

// archetypal alias for a 'tag' type member in user callbacks
template <typename Callback>
using CallbackTagArchetypeAlias = typename Callback::tag;
Expand Down
24 changes: 12 additions & 12 deletions src/details/ArborX_DetailsDistributedTreeImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ struct DistributedTreeImpl
{
// spatial queries
template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename IndicesAndRanks, typename Offset>
static std::enable_if_t<Kokkos::is_view_v<IndicesAndRanks> &&
typename Predicates, typename Values, typename Offset>
static std::enable_if_t<Kokkos::is_view_v<Values> &&
Kokkos::is_view_v<Offset>>
queryDispatch(SpatialPredicateTag, DistributedTree const &tree,
ExecutionSpace const &space, Predicates const &queries,
IndicesAndRanks &values, Offset &offset);
Values &values, Offset &offset);

template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename OutputView, typename OffsetView,
Expand All @@ -44,24 +44,24 @@ struct DistributedTreeImpl
template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename Callback, typename Indices,
typename Offset>
static std::enable_if_t<Kokkos::is_view_v<Indices> &&
Kokkos::is_view_v<Offset>>
queryDispatchImpl(NearestPredicateTag, DistributedTree const &tree,
ExecutionSpace const &space, Predicates const &queries,
Callback const &callback, Indices &indices, Offset &offset);
static void
queryDispatch2RoundImpl(NearestPredicateTag, DistributedTree const &tree,
ExecutionSpace const &space,
Predicates const &queries, Callback const &callback,
Indices &indices, Offset &offset);

template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename IndicesAndRanks, typename Offset>
static std::enable_if_t<Kokkos::is_view_v<IndicesAndRanks> &&
typename Predicates, typename Values, typename Offset>
static std::enable_if_t<Kokkos::is_view_v<Values> &&
Kokkos::is_view_v<Offset>>
queryDispatch(NearestPredicateTag tag, DistributedTree const &tree,
ExecutionSpace const &space, Predicates const &queries,
IndicesAndRanks &values, Offset &offset);
Values &values, Offset &offset);
template <typename Tree, typename ExecutionSpace, typename Predicates,
typename Callback, typename Values, typename Offset>
static std::enable_if_t<Kokkos::is_view_v<Values> &&
Kokkos::is_view_v<Offset>>
queryDispatch(NearestPredicateTag tag, Tree const &tree,
queryDispatch(NearestPredicateTag, Tree const &tree,
ExecutionSpace const &space, Predicates const &predicates,
Callback const &callback, Values &values, Offset &offset);

Expand Down
22 changes: 12 additions & 10 deletions src/details/ArborX_DetailsDistributedTreeNearest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,10 @@ void DistributedTreeImpl::phaseII(ExecutionSpace const &space, Tree const &tree,

template <typename Tree, typename ExecutionSpace, typename Predicates,
typename Callback, typename Values, typename Offset>
std::enable_if_t<Kokkos::is_view_v<Values> && Kokkos::is_view_v<Offset>>
DistributedTreeImpl::queryDispatchImpl(NearestPredicateTag, Tree const &tree,
ExecutionSpace const &space,
Predicates const &predicates,
Callback const &callback, Values &values,
Offset &offset)
void DistributedTreeImpl::queryDispatch2RoundImpl(
NearestPredicateTag, Tree const &tree, ExecutionSpace const &space,
Predicates const &predicates, Callback const &callback, Values &values,
Offset &offset)
{
std::string prefix = "ArborX::DistributedTree::query::nearest";

Expand Down Expand Up @@ -222,20 +220,24 @@ DistributedTreeImpl::queryDispatch(NearestPredicateTag tag, Tree const &tree,
Predicates const &predicates, Values &values,
Offset &offset)
{
queryDispatchImpl(tag, tree, space, predicates, DefaultCallback{}, values,
offset);
queryDispatch2RoundImpl(tag, tree, space, predicates, DefaultCallback{},
values, offset);
}

template <typename Tree, typename ExecutionSpace, typename Predicates,
typename Callback, typename Values, typename Offset>
std::enable_if_t<Kokkos::is_view_v<Values> && Kokkos::is_view_v<Offset>>
DistributedTreeImpl::queryDispatch(NearestPredicateTag tag, Tree const &tree,
DistributedTreeImpl::queryDispatch(NearestPredicateTag, Tree const &tree,
ExecutionSpace const &space,
Predicates const &predicates,
Callback const &callback, Values &values,
Offset &offset)
{
queryDispatchImpl(tag, tree, space, predicates, callback, values, offset);
static_assert(Kokkos::is_detected_v<CallbackTagArchetypeAlias, Callback>);
static_assert(
std::is_same_v<typename Callback::tag, ConstrainedNearestCallbackTag>);
queryDispatch2RoundImpl(NearestPredicateTag{}, tree, space, predicates,
callback, values, offset);
}

} // namespace ArborX::Details
Expand Down
18 changes: 2 additions & 16 deletions src/details/ArborX_DetailsLegacy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct LegacyCallbackWrapper
PairValueIndex<Value, Index> const &value,
Output const &out) const
{
_callback(predicate, value.index, out);
_callback(predicate, (int)value.index, out);
}
};

Expand All @@ -96,24 +96,10 @@ struct LegacyDefaultCallback
PairValueIndex<Value, Index> const &value,
OutputFunctor const &output) const
{
output(value.index);
output((int)value.index);
}
};

#ifdef ARBORX_ENABLE_MPI
struct LegacyDefaultCallbackWithRank
{
int _rank;

template <typename Predicate, typename OutputFunctor>
KOKKOS_FUNCTION void operator()(Predicate const &, int primitive_index,
OutputFunctor const &out) const
{
out({primitive_index, _rank});
}
};
#endif

struct LegacyDefaultTemplateValue
{};

Expand Down
2 changes: 2 additions & 0 deletions test/tstDistributedTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ struct PairRankIndex

struct DistributedNearestCallback
{
using tag = ArborX::Details::ConstrainedNearestCallbackTag;

int rank;

template <typename Predicate, typename Value, typename OutputFunctor>
Expand Down

0 comments on commit ab1eb5d

Please sign in to comment.