Skip to content

Commit

Permalink
Merge pull request #1051 from aprokop/simplify_apiv2_in_apiv1_mode
Browse files Browse the repository at this point in the history
Simplify using APIv1 through APIv2 interface
  • Loading branch information
aprokop committed Apr 11, 2024
2 parents 4dc7cbe + 9a32bf4 commit 5cae9f6
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 25 deletions.
6 changes: 2 additions & 4 deletions benchmarks/brute_force_vs_bvh/brute_force_vs_bvh_timpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ static void run_fp(int nprimitives, int nqueries, int nrepeats)

Kokkos::View<int *, ExecutionSpace> indices("Benchmark::indices_ref", 0);
Kokkos::View<int *, ExecutionSpace> offset("Benchmark::offset_ref", 0);
bvh.query(space, predicates, ArborX::Details::LegacyDefaultCallback{},
indices, offset);
bvh.query(space, predicates, indices, offset);

space.fence();
double time = timer.seconds();
Expand All @@ -114,8 +113,7 @@ static void run_fp(int nprimitives, int nqueries, int nrepeats)

Kokkos::View<int *, ExecutionSpace> indices("Benchmark::indices", 0);
Kokkos::View<int *, ExecutionSpace> offset("Benchmark::offset", 0);
brute.query(space, predicates, ArborX::Details::LegacyDefaultCallback{},
indices, offset);
brute.query(space, predicates, indices, offset);

space.fence();
double time = timer.seconds();
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/dbscan/ArborX_DBSCANVerification.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,7 @@ bool verifyDBSCAN(ExecutionSpace exec_space, Primitives const &primitives,

Kokkos::View<int *, MemorySpace> indices("ArborX::DBSCAN::indices", 0);
Kokkos::View<int *, MemorySpace> offset("ArborX::DBSCAN::offset", 0);
ArborX::query(bvh, exec_space, predicates,
ArborX::Details::LegacyDefaultCallback{}, indices, offset);
ArborX::query(bvh, exec_space, predicates, indices, offset);

auto passed = Details::verifyClusters(exec_space, indices, offset, labels,
core_min_size);
Expand Down
26 changes: 15 additions & 11 deletions examples/simple_intersection/example_intersection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ int main(int argc, char *argv[])
using ExecutionSpace = Kokkos::DefaultExecutionSpace;
using MemorySpace = ExecutionSpace::memory_space;

Kokkos::View<ArborX::Box *, MemorySpace> boxes("Example::boxes", 4);
using Box = ArborX::ExperimentalHyperGeometry::Box<2>;
using Point = ArborX::ExperimentalHyperGeometry::Point<2>;

Kokkos::View<Box *, MemorySpace> boxes("Example::boxes", 4);
auto boxes_host = Kokkos::create_mirror_view(boxes);
boxes_host[0] = {{0, 0, 0}, {1, 1, 1}};
boxes_host[1] = {{1, 0, 0}, {2, 1, 1}};
boxes_host[2] = {{0, 1, 0}, {1, 2, 1}};
boxes_host[3] = {{1, 1, 0}, {2, 2, 1}};
boxes_host[0] = {{0, 0}, {1, 1}};
boxes_host[1] = {{1, 0}, {2, 1}};
boxes_host[2] = {{0, 1}, {1, 2}};
boxes_host[3] = {{1, 1}, {2, 2}};
Kokkos::deep_copy(boxes, boxes_host);

// -----------
Expand All @@ -37,17 +40,18 @@ int main(int argc, char *argv[])
// | | |
// | | |
// -----------
Kokkos::View<decltype(ArborX::intersects(ArborX::Point())) *, MemorySpace>
queries("Example::queries", 3);
Kokkos::View<decltype(ArborX::intersects(Point{})) *, MemorySpace> queries(
"Example::queries", 3);
auto queries_host = Kokkos::create_mirror_view(queries);
queries_host[0] = ArborX::intersects(ArborX::Point{1.8, 1.5, 0.5});
queries_host[1] = ArborX::intersects(ArborX::Point{1.3, 1.7, 0.5});
queries_host[2] = ArborX::intersects(ArborX::Point{1, 1, 0.5});
queries_host[0] = ArborX::intersects(Point{1.8, 1.5});
queries_host[1] = ArborX::intersects(Point{1.3, 1.7});
queries_host[2] = ArborX::intersects(Point{1, 1});
Kokkos::deep_copy(queries, queries_host);

ExecutionSpace space;

ArborX::BVH<MemorySpace> const tree(space, boxes);
ArborX::BVH<MemorySpace, ArborX::PairValueIndex<Box>> const tree(
space, ArborX::Experimental::attach_indices(boxes));

// The query will resize indices and offsets accordingly
Kokkos::View<int *, MemorySpace> indices("Example::indices", 0);
Expand Down
35 changes: 31 additions & 4 deletions src/ArborX_BruteForce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,37 @@ class BruteForce
using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
using Tag = typename Predicates::value_type::Tag;

Details::CrsGraphWrapperImpl::queryDispatch(
Tag{}, *this, space, Predicates{user_predicates},
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
// Automatically add LegacyDefaultCallback if
// 1. A user does not provide a callback
// 2. The index is constructed on PairValueIndex
// 3. The output value_type is an integral type
constexpr bool use_convenient_shortcut = []() {
if constexpr (!Kokkos::is_view_v<std::decay_t<CallbackOrView>>)
return false;
else if constexpr (!Details::is_pair_value_index_v<value_type>)
return false;
else
return std::is_integral_v<
typename std::decay_t<CallbackOrView>::value_type>;
}();

if constexpr (use_convenient_shortcut)
{
// Simplified way to get APIv1 result using APIv2 interface
Details::CrsGraphWrapperImpl::queryDispatch(
Tag{}, *this, space, Predicates{user_predicates},
Details::LegacyDefaultCallback{}, // inject legacy callback arg
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
return;
}
else
{
Details::CrsGraphWrapperImpl::queryDispatch(
Tag{}, *this, space, Predicates{user_predicates},
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
}
}

KOKKOS_FUNCTION auto const &indexable_get() const
Expand Down
35 changes: 31 additions & 4 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,37 @@ class BoundingVolumeHierarchy
using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
using Tag = typename Predicates::value_type::Tag;

Details::CrsGraphWrapperImpl::queryDispatch(
Tag{}, *this, space, Predicates{user_predicates},
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
// Automatically add LegacyDefaultCallback if
// 1. A user does not provide a callback
// 2. The index is constructed on PairValueIndex
// 3. The output value_type is an integral type
constexpr bool use_convenient_shortcut = []() {
if constexpr (!Kokkos::is_view_v<std::decay_t<CallbackOrView>>)
return false;
else if constexpr (!Details::is_pair_value_index_v<value_type>)
return false;
else
return std::is_integral_v<
typename std::decay_t<CallbackOrView>::value_type>;
}();

if constexpr (use_convenient_shortcut)
{
// Simplified way to get APIv1 result using APIv2 interface
Details::CrsGraphWrapperImpl::queryDispatch(
Tag{}, *this, space, Predicates{user_predicates},
Details::LegacyDefaultCallback{}, // inject legacy callback arg
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
return;
}
else
{
Details::CrsGraphWrapperImpl::queryDispatch(
Tag{}, *this, space, Predicates{user_predicates},
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
}
}

template <typename Predicate, typename Callback>
Expand Down
15 changes: 15 additions & 0 deletions src/details/ArborX_PairValueIndex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@ struct PairValueIndex
Index index;
};

namespace Details
{
template <typename T>
struct is_pair_value_index : public std::false_type
{};

template <typename Value, typename Index>
struct is_pair_value_index<PairValueIndex<Value, Index>> : public std::true_type
{};

template <typename T>
inline constexpr bool is_pair_value_index_v = is_pair_value_index<T>::value;

} // namespace Details

} // namespace ArborX

#endif

0 comments on commit 5cae9f6

Please sign in to comment.