Skip to content

Commit

Permalink
Introduce LegacyCallbackWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed May 18, 2023
1 parent d99b6cd commit 13c835a
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 18 deletions.
9 changes: 6 additions & 3 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class BasicBoundingVolumeHierarchy
static_assert(Kokkos::is_memory_space<MemorySpace>::value);
using size_type = typename MemorySpace::size_type;
using bounding_volume_type = BoundingVolume;
using value_type = Details::PairIndexVolume<bounding_volume_type>;

BasicBoundingVolumeHierarchy() = default; // build an empty tree

Expand Down Expand Up @@ -85,7 +86,6 @@ class BasicBoundingVolumeHierarchy
private:
friend struct Details::HappyTreeFriends;

using value_type = Details::PairIndexVolume<bounding_volume_type>;
using leaf_node_type = Details::LeafNode<value_type>;
using internal_node_type = Details::InternalNode<bounding_volume_type>;

Expand Down Expand Up @@ -213,7 +213,8 @@ template <typename MemorySpace, typename BoundingVolume, typename Enable>
template <typename ExecutionSpace, typename Predicates, typename Callback>
void BasicBoundingVolumeHierarchy<MemorySpace, BoundingVolume, Enable>::query(
ExecutionSpace const &space, Predicates const &predicates,
Callback const &callback, Experimental::TraversalPolicy const &policy) const
Callback const &legacy_callback,
Experimental::TraversalPolicy const &policy) const
{
static_assert(
KokkosExt::is_accessible_from<MemorySpace, ExecutionSpace>::value);
Expand All @@ -222,7 +223,9 @@ void BasicBoundingVolumeHierarchy<MemorySpace, BoundingVolume, Enable>::query(
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");
Details::check_valid_callback(callback, predicates);
Details::check_valid_callback(legacy_callback, predicates);
Details::LegacyCallbackWrapper<Callback, value_type> callback{
legacy_callback};

using Tag = typename Details::AccessTraitsHelper<Access>::tag;
std::string profiling_prefix = "ArborX::BVH::query::";
Expand Down
13 changes: 13 additions & 0 deletions src/details/ArborX_Callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,19 @@ void check_valid_callback(Callback const &callback, Predicates const &)
"Callback 'operator()' return type must be void");
}

template <typename Callback, typename Value>
struct LegacyCallbackWrapper
{
Callback _callback;

template <typename Predicate>
KOKKOS_FUNCTION auto operator()(Predicate const &predicate,
Value const &value) const
{
return _callback(predicate, value.index);
}
};

} // namespace Details
} // namespace ArborX

Expand Down
6 changes: 3 additions & 3 deletions src/details/ArborX_DetailsHalfTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef ARBORX_DETAILS_HALF_TRAVERSAL_HPP
#define ARBORX_DETAILS_HALF_TRAVERSAL_HPP

#include <ArborX_Callbacks.hpp> // LegacyCallbackWrapper
#include <ArborX_DetailsHappyTreeFriends.hpp>
#include <ArborX_DetailsNode.hpp> // ROPE_SENTINEL

Expand All @@ -25,7 +26,7 @@ struct HalfTraversal
{
BVH _bvh;
PredicateGetter _get_predicate;
Callback _callback;
LegacyCallbackWrapper<Callback, typename BVH::value_type> _callback;

template <class ExecutionSpace>
HalfTraversal(ExecutionSpace const &space, BVH const &bvh,
Expand Down Expand Up @@ -68,8 +69,7 @@ struct HalfTraversal
{
if (is_leaf)
{
_callback(leaf_permutation_i,
HappyTreeFriends::getValue(_bvh, node).index);
_callback(leaf_permutation_i, HappyTreeFriends::getValue(_bvh, node));
node = HappyTreeFriends::getRope(_bvh, node);
}
else
Expand Down
17 changes: 7 additions & 10 deletions src/details/ArborX_DetailsTreeTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
HappyTreeFriends::getLeafBoundingVolume(_bvh, root);
if (predicate(root_bounding_volume))
{
_callback(predicate, HappyTreeFriends::getValue(_bvh, 0).index);
_callback(predicate, HappyTreeFriends::getValue(_bvh, 0));
}
}

Expand All @@ -102,8 +102,7 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
if (is_leaf)
{
if (invoke_callback_and_check_early_exit(
_callback, predicate,
HappyTreeFriends::getValue(_bvh, node).index))
_callback, predicate, HappyTreeFriends::getValue(_bvh, node)))
return;
node = HappyTreeFriends::getRope(_bvh, node);
}
Expand Down Expand Up @@ -218,7 +217,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
if (k < 1)
return;

_callback(predicate, HappyTreeFriends::getValue(_bvh, 0).index);
_callback(predicate, HappyTreeFriends::getValue(_bvh, 0));
}

KOKKOS_FUNCTION void operator()(int queryIndex) const
Expand Down Expand Up @@ -378,9 +377,8 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
sortHeap(heap.data(), heap.data() + heap.size(), heap.valueComp());
for (decltype(heap.size()) i = 0; i < heap.size(); ++i)
{
_callback(
predicate,
HappyTreeFriends::getValue(_bvh, (heap.data() + i)->first).index);
_callback(predicate,
HappyTreeFriends::getValue(_bvh, (heap.data() + i)->first));
}
}
};
Expand Down Expand Up @@ -440,7 +438,7 @@ struct TreeTraversal<BVH, Predicates, Callback,
KokkosExt::ArithmeticTraits::infinity<distance_type>::value;
if (distance(getGeometry(predicate), root_bounding_volume) != inf)
{
_callback(predicate, HappyTreeFriends::getValue(_bvh, 0).index);
_callback(predicate, HappyTreeFriends::getValue(_bvh, 0));
}
}

Expand Down Expand Up @@ -487,8 +485,7 @@ struct TreeTraversal<BVH, Predicates, Callback,
if (HappyTreeFriends::isLeaf(_bvh, node))
{
if (invoke_callback_and_check_early_exit(
_callback, predicate,
HappyTreeFriends::getValue(_bvh, node).index))
_callback, predicate, HappyTreeFriends::getValue(_bvh, node)))
return;

if (heap.empty())
Expand Down
5 changes: 3 additions & 2 deletions src/details/ArborX_DetailsTreeVisualization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,10 @@ struct TreeVisualization
struct VisitorCallback
{
template <typename Query>
KOKKOS_FUNCTION void operator()(Query const &, int index) const
KOKKOS_FUNCTION void
operator()(Query const &, typename TreeType::value_type const &value) const
{
_visitor.visit(_tree, permute(index));
_visitor.visit(_tree, permute(value.index));
}

TreeType _tree;
Expand Down

0 comments on commit 13c835a

Please sign in to comment.