Skip to content

Commit

Permalink
Separate HappyTreeFunctions for internal and leaf nodes (#864)
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed May 18, 2023
1 parent c667048 commit bf0fb97
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/details/ArborX_DetailsDistributedTreeImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ struct CallbackWithDistance
// need to be fixed with a proper callback abstraction.
int const leaf_node_index = _rev_permute(index);
auto const &leaf_node_bounding_volume =
HappyTreeFriends::getBoundingVolume(_tree, leaf_node_index);
HappyTreeFriends::getLeafBoundingVolume(_tree, leaf_node_index);
out({index, distance(getGeometry(query), leaf_node_bounding_volume)});
}
};
Expand Down
19 changes: 12 additions & 7 deletions src/details/ArborX_DetailsHalfTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,30 @@ struct HalfTraversal
KOKKOS_FUNCTION void operator()(int i) const
{
auto const predicate =
_get_predicate(HappyTreeFriends::getBoundingVolume(_bvh, i));
_get_predicate(HappyTreeFriends::getLeafBoundingVolume(_bvh, i));
auto const leaf_permutation_i =
HappyTreeFriends::getLeafPermutationIndex(_bvh, i);

int node = HappyTreeFriends::getRope(_bvh, i);
while (node != ROPE_SENTINEL)
{
if (predicate(HappyTreeFriends::getBoundingVolume(_bvh, node)))
bool const is_leaf = HappyTreeFriends::isLeaf(_bvh, node);

if (predicate(
(is_leaf
? HappyTreeFriends::getLeafBoundingVolume(_bvh, node)
: HappyTreeFriends::getInternalBoundingVolume(_bvh, node))))
{
if (!HappyTreeFriends::isLeaf(_bvh, node))
{
node = HappyTreeFriends::getLeftChild(_bvh, node);
}
else
if (is_leaf)
{
_callback(leaf_permutation_i,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node));
node = HappyTreeFriends::getRope(_bvh, node);
}
else
{
node = HappyTreeFriends::getLeftChild(_bvh, node);
}
}
else
{
Expand Down
39 changes: 26 additions & 13 deletions src/details/ArborX_DetailsHappyTreeFriends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
#include <type_traits>
#include <utility> // declval

namespace ArborX
{
namespace Details
namespace ArborX::Details
{

struct HappyTreeFriends
{
template <class BVH>
Expand All @@ -47,16 +46,32 @@ struct HappyTreeFriends
}

template <class BVH>
static KOKKOS_FUNCTION
// FIXME_HIP See https://github.com/arborx/ArborX/issues/553
#ifdef __HIP_DEVICE_COMPILE__
auto
#else
auto const &
#endif
getInternalBoundingVolume(BVH const &bvh, int i)
{
return bvh._internal_nodes(internalIndex(bvh, i)).bounding_volume;
}

template <class BVH>
static KOKKOS_FUNCTION
// FIXME_HIP See https://github.com/arborx/ArborX/issues/553
#ifdef __HIP_DEVICE_COMPILE__
static KOKKOS_FUNCTION auto getBoundingVolume(BVH const &bvh, int i)
auto
#else
static KOKKOS_FUNCTION auto const &getBoundingVolume(BVH const &bvh, int i)
auto const &
#endif
getLeafBoundingVolume(BVH const &bvh, int i)
{
auto const internal_i = internalIndex(bvh, i);
return (internal_i >= 0 ? bvh._internal_nodes(internal_i).bounding_volume
: bvh._leaf_nodes(i).bounding_volume);
static_assert(
std::is_same_v<decltype(bvh._internal_nodes(0).bounding_volume),
decltype(bvh._leaf_nodes(0).bounding_volume)>);
return bvh._leaf_nodes(i).bounding_volume;
}

template <class BVH>
Expand All @@ -83,12 +98,10 @@ struct HappyTreeFriends
template <class BVH>
static KOKKOS_FUNCTION auto getRope(BVH const &bvh, int i)
{
auto const internal_i = internalIndex(bvh, i);
return (internal_i >= 0 ? bvh._internal_nodes(internal_i).rope
: bvh._leaf_nodes(i).rope);
return (isLeaf(bvh, i) ? bvh._leaf_nodes(i).rope
: bvh._internal_nodes(internalIndex(bvh, i)).rope);
}
};
} // namespace Details
} // namespace ArborX
} // namespace ArborX::Details

#endif
56 changes: 36 additions & 20 deletions src/details/ArborX_DetailsTreeTraversal.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/****************************************************************************
* Copyright (c) 2017-2022 by the ArborX authors *
* Copyright (c) 2017-2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
Expand Down Expand Up @@ -78,7 +78,7 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
auto const &predicate = Access::get(_predicates, queryIndex);
auto const root = 0;
auto const &root_bounding_volume =
HappyTreeFriends::getBoundingVolume(_bvh, root);
HappyTreeFriends::getLeafBoundingVolume(_bvh, root);
if (predicate(root_bounding_volume))
{
_callback(predicate, 0);
Expand All @@ -92,20 +92,25 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
int node = HappyTreeFriends::getRoot(_bvh); // start with root
do
{
if (predicate(HappyTreeFriends::getBoundingVolume(_bvh, node)))
bool const is_leaf = HappyTreeFriends::isLeaf(_bvh, node);

if (predicate(
(is_leaf
? HappyTreeFriends::getLeafBoundingVolume(_bvh, node)
: HappyTreeFriends::getInternalBoundingVolume(_bvh, node))))
{
if (!HappyTreeFriends::isLeaf(_bvh, node))
{
node = HappyTreeFriends::getLeftChild(_bvh, node);
}
else
if (is_leaf)
{
if (invoke_callback_and_check_early_exit(
_callback, predicate,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node)))
return;
node = HappyTreeFriends::getRope(_bvh, node);
}
else
{
node = HappyTreeFriends::getLeftChild(_bvh, node);
}
}
else
{
Expand Down Expand Up @@ -255,6 +260,14 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
heap(UnmanagedStaticVector<PairIndexDistance>(buffer.data(),
buffer.size()));

auto &bvh = _bvh;
auto const distance = [&predicate, &bvh](int j) {
return predicate.distance(
HappyTreeFriends::isLeaf(bvh, j)
? HappyTreeFriends::getLeafBoundingVolume(bvh, j)
: HappyTreeFriends::getInternalBoundingVolume(bvh, j));
};

constexpr int SENTINEL = -1;
int stack[64];
auto *stack_ptr = stack;
Expand Down Expand Up @@ -285,10 +298,8 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
left_child = HappyTreeFriends::getLeftChild(_bvh, node);
right_child = HappyTreeFriends::getRightChild(_bvh, node);

distance_left = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, left_child));
distance_right = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, right_child));
distance_left = distance(left_child);
distance_right = distance(right_child);

if (distance_left < radius)
{
Expand Down Expand Up @@ -337,8 +348,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
// This is a theoretically unnecessary duplication of distance
// calculation for stack nodes. However, for Cuda it's better than
// putting the distances in stack.
distance_node = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, node));
distance_node = distance(node);
}
#else
distance_node = *--stack_distance_ptr;
Expand Down Expand Up @@ -423,7 +433,7 @@ struct TreeTraversal<BVH, Predicates, Callback,
auto const &predicate = Access::get(_predicates, queryIndex);
auto const root = 0;
auto const &root_bounding_volume =
HappyTreeFriends::getBoundingVolume(_bvh, root);
HappyTreeFriends::getLeafBoundingVolume(_bvh, root);
using distance_type =
decltype(distance(getGeometry(predicate), root_bounding_volume));
constexpr auto inf =
Expand All @@ -440,7 +450,7 @@ struct TreeTraversal<BVH, Predicates, Callback,
using ArborX::Details::HappyTreeFriends;

using distance_type = decltype(predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, 0)));
HappyTreeFriends::getInternalBoundingVolume(_bvh, 0)));
using PairIndexDistance = Kokkos::pair<int, distance_type>;
struct CompareDistance
{
Expand All @@ -460,6 +470,14 @@ struct TreeTraversal<BVH, Predicates, Callback,
constexpr auto inf =
KokkosExt::ArithmeticTraits::infinity<distance_type>::value;

auto &bvh = _bvh;
auto const distance = [&predicate, &bvh](int j) {
return predicate.distance(
HappyTreeFriends::isLeaf(bvh, j)
? HappyTreeFriends::getLeafBoundingVolume(bvh, j)
: HappyTreeFriends::getInternalBoundingVolume(bvh, j));
};

int node = HappyTreeFriends::getRoot(_bvh);
int left_child;
int right_child;
Expand All @@ -484,12 +502,10 @@ struct TreeTraversal<BVH, Predicates, Callback,
left_child = HappyTreeFriends::getLeftChild(_bvh, node);
right_child = HappyTreeFriends::getRightChild(_bvh, node);

auto const distance_left = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, left_child));
auto const distance_left = distance(left_child);
auto const left_pair = Kokkos::make_pair(left_child, distance_left);

auto const distance_right = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, right_child));
auto const distance_right = distance(right_child);
auto const right_pair = Kokkos::make_pair(right_child, distance_right);

auto const &closer_pair =
Expand Down
6 changes: 4 additions & 2 deletions src/details/ArborX_DetailsTreeVisualization.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/****************************************************************************
* Copyright (c) 2017-2022 by the ArborX authors *
* Copyright (c) 2017-2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
Expand Down Expand Up @@ -121,7 +121,9 @@ struct TreeVisualization
auto const node_label = getNodeLabel(tree, node);
auto const node_attributes = getNodeAttributes(tree, node);
auto const bounding_volume =
HappyTreeFriends::getBoundingVolume(tree, node);
HappyTreeFriends::isLeaf(tree, node)
? HappyTreeFriends::getLeafBoundingVolume(tree, node)
: HappyTreeFriends::getInternalBoundingVolume(tree, node);
auto const min_corner = bounding_volume.minCorner();
auto const max_corner = bounding_volume.maxCorner();
_os << R"(\draw)" << node_attributes << " " << min_corner << " rectangle "
Expand Down
13 changes: 8 additions & 5 deletions src/details/ArborX_MinimumSpanningTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,14 @@ struct FindComponentNearestNeighbors
constexpr auto inf = KokkosExt::ArithmeticTraits::infinity<float>::value;

auto const distance = [bounding_volume_i =
HappyTreeFriends::getBoundingVolume(_bvh, i),
HappyTreeFriends::getLeafBoundingVolume(_bvh, i),
&bvh = _bvh](int j) {
using Details::distance;
return distance(bounding_volume_i,
HappyTreeFriends::getBoundingVolume(bvh, j));
auto &&bounding_volume_j =
(HappyTreeFriends::isLeaf(bvh, j)
? HappyTreeFriends::getLeafBoundingVolume(bvh, j)
: HappyTreeFriends::getInternalBoundingVolume(bvh, j));
return distance(bounding_volume_i, bounding_volume_j);
};

auto const component = _labels(i);
Expand Down Expand Up @@ -680,8 +683,8 @@ void resetSharedRadii(ExecutionSpace const &space, BVH const &bvh,
auto const r =
metric(HappyTreeFriends::getLeafPermutationIndex(bvh, i),
HappyTreeFriends::getLeafPermutationIndex(bvh, j),
distance(HappyTreeFriends::getBoundingVolume(bvh, i),
HappyTreeFriends::getBoundingVolume(bvh, j)));
distance(HappyTreeFriends::getLeafBoundingVolume(bvh, i),
HappyTreeFriends::getLeafBoundingVolume(bvh, j)));
Kokkos::atomic_min(&radii(label_i), r);
Kokkos::atomic_min(&radii(label_j), r);
}
Expand Down

0 comments on commit bf0fb97

Please sign in to comment.