Skip to content

Commit

Permalink
Merge pull request #871 from aprokop/introduce_value
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed May 18, 2023
2 parents bf0fb97 + 13c835a commit f2e69c9
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 68 deletions.
12 changes: 8 additions & 4 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class BasicBoundingVolumeHierarchy
static_assert(Kokkos::is_memory_space<MemorySpace>::value);
using size_type = typename MemorySpace::size_type;
using bounding_volume_type = BoundingVolume;
using value_type = Details::PairIndexVolume<bounding_volume_type>;

BasicBoundingVolumeHierarchy() = default; // build an empty tree

Expand Down Expand Up @@ -85,7 +86,7 @@ class BasicBoundingVolumeHierarchy
private:
friend struct Details::HappyTreeFriends;

using leaf_node_type = Details::LeafNode<bounding_volume_type>;
using leaf_node_type = Details::LeafNode<value_type>;
using internal_node_type = Details::InternalNode<bounding_volume_type>;

KOKKOS_FUNCTION
Expand All @@ -98,7 +99,7 @@ class BasicBoundingVolumeHierarchy
assert((n == 1 || Details::HappyTreeFriends::getRoot(*this) == n) &&
"workaround below assumes root is stored as first element");
return (n > 1 ? &_internal_nodes.data()->bounding_volume
: &_leaf_nodes.data()->bounding_volume);
: &_leaf_nodes.data()->value.bounding_volume);
}

size_type _size{0};
Expand Down Expand Up @@ -212,7 +213,8 @@ template <typename MemorySpace, typename BoundingVolume, typename Enable>
template <typename ExecutionSpace, typename Predicates, typename Callback>
void BasicBoundingVolumeHierarchy<MemorySpace, BoundingVolume, Enable>::query(
ExecutionSpace const &space, Predicates const &predicates,
Callback const &callback, Experimental::TraversalPolicy const &policy) const
Callback const &legacy_callback,
Experimental::TraversalPolicy const &policy) const
{
static_assert(
KokkosExt::is_accessible_from<MemorySpace, ExecutionSpace>::value);
Expand All @@ -221,7 +223,9 @@ void BasicBoundingVolumeHierarchy<MemorySpace, BoundingVolume, Enable>::query(
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");
Details::check_valid_callback(callback, predicates);
Details::check_valid_callback(legacy_callback, predicates);
Details::LegacyCallbackWrapper<Callback, value_type> callback{
legacy_callback};

using Tag = typename Details::AccessTraitsHelper<Access>::tag;
std::string profiling_prefix = "ArborX::BVH::query::";
Expand Down
13 changes: 13 additions & 0 deletions src/details/ArborX_Callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,19 @@ void check_valid_callback(Callback const &callback, Predicates const &)
"Callback 'operator()' return type must be void");
}

template <typename Callback, typename Value>
struct LegacyCallbackWrapper
{
Callback _callback;

template <typename Predicate>
KOKKOS_FUNCTION auto operator()(Predicate const &predicate,
Value const &value) const
{
return _callback(predicate, value.index);
}
};

} // namespace Details
} // namespace ArborX

Expand Down
3 changes: 1 addition & 2 deletions src/details/ArborX_DetailsDistributedTreeImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,7 @@ struct CallbackWithDistance
"compute_reverse_permutation",
Kokkos::RangePolicy<ExecutionSpace>(exec_space, 0, n),
ARBORX_CLASS_LAMBDA(int const i) {
_rev_permute(HappyTreeFriends::getLeafPermutationIndex(_tree, i)) =
i;
_rev_permute(HappyTreeFriends::getValue(_tree, i).index) = i;
});
}
}
Expand Down
9 changes: 4 additions & 5 deletions src/details/ArborX_DetailsHalfTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef ARBORX_DETAILS_HALF_TRAVERSAL_HPP
#define ARBORX_DETAILS_HALF_TRAVERSAL_HPP

#include <ArborX_Callbacks.hpp> // LegacyCallbackWrapper
#include <ArborX_DetailsHappyTreeFriends.hpp>
#include <ArborX_DetailsNode.hpp> // ROPE_SENTINEL

Expand All @@ -25,7 +26,7 @@ struct HalfTraversal
{
BVH _bvh;
PredicateGetter _get_predicate;
Callback _callback;
LegacyCallbackWrapper<Callback, typename BVH::value_type> _callback;

template <class ExecutionSpace>
HalfTraversal(ExecutionSpace const &space, BVH const &bvh,
Expand Down Expand Up @@ -54,8 +55,7 @@ struct HalfTraversal
{
auto const predicate =
_get_predicate(HappyTreeFriends::getLeafBoundingVolume(_bvh, i));
auto const leaf_permutation_i =
HappyTreeFriends::getLeafPermutationIndex(_bvh, i);
auto const leaf_permutation_i = HappyTreeFriends::getValue(_bvh, i).index;

int node = HappyTreeFriends::getRope(_bvh, i);
while (node != ROPE_SENTINEL)
Expand All @@ -69,8 +69,7 @@ struct HalfTraversal
{
if (is_leaf)
{
_callback(leaf_permutation_i,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node));
_callback(leaf_permutation_i, HappyTreeFriends::getValue(_bvh, node));
node = HappyTreeFriends::getRope(_bvh, node);
}
else
Expand Down
8 changes: 4 additions & 4 deletions src/details/ArborX_DetailsHappyTreeFriends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ struct HappyTreeFriends
{
static_assert(
std::is_same_v<decltype(bvh._internal_nodes(0).bounding_volume),
decltype(bvh._leaf_nodes(0).bounding_volume)>);
return bvh._leaf_nodes(i).bounding_volume;
decltype(bvh._leaf_nodes(0).value.bounding_volume)>);
return bvh._leaf_nodes(i).value.bounding_volume;
}

template <class BVH>
static KOKKOS_FUNCTION auto getLeafPermutationIndex(BVH const &bvh, int i)
static KOKKOS_FUNCTION auto const &getValue(BVH const &bvh, int i)
{
assert(i >= 0 && i < (int)bvh.size());
return bvh._leaf_nodes(i).permutation_index;
return bvh._leaf_nodes(i).value;
}

template <class BVH>
Expand Down
30 changes: 16 additions & 14 deletions src/details/ArborX_DetailsNode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,27 @@
#include <Kokkos_Macros.hpp>

#include <cassert>
#include <climits> // UINT_MAX
#include <utility> // std::move

namespace ArborX
{
namespace Details
namespace ArborX::Details
{

constexpr int ROPE_SENTINEL = -1;

template <class BoundingVolume>
struct PairIndexVolume
{
unsigned index;
BoundingVolume bounding_volume;
};

template <class Value>
struct LeafNode
{
using bounding_volume_type = BoundingVolume;
using value_type = Value;

unsigned permutation_index = UINT_MAX;
int rope = ROPE_SENTINEL;
BoundingVolume bounding_volume;
Value value;
};

template <class BoundingVolume>
Expand All @@ -48,14 +51,13 @@ struct InternalNode
BoundingVolume bounding_volume;
};

template <class BoundingVolume>
KOKKOS_INLINE_FUNCTION constexpr LeafNode<BoundingVolume>
makeLeafNode(unsigned permutation_index,
BoundingVolume bounding_volume) noexcept
template <class Value>
KOKKOS_INLINE_FUNCTION constexpr LeafNode<Value>
makeLeafNode(Value value) noexcept
{
return {permutation_index, ROPE_SENTINEL, std::move(bounding_volume)};
return {ROPE_SENTINEL, std::move(value)};
}
} // namespace Details
} // namespace ArborX

} // namespace ArborX::Details

#endif
14 changes: 7 additions & 7 deletions src/details/ArborX_DetailsTreeConstruction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,18 @@ inline void initializeSingleLeafNode(ExecutionSpace const &space,
Nodes const &leaf_nodes)
{
using Access = AccessTraits<Primitives, PrimitivesTag>;
using Value = typename Nodes::value_type::value_type;
using BoundingVolume = decltype(std::declval<Value>().bounding_volume);

ARBORX_ASSERT(leaf_nodes.extent(0) == 1);
ARBORX_ASSERT(Access::size(primitives) == 1);

using Node = typename Nodes::value_type;
using BoundingVolume = typename Node::bounding_volume_type;

Kokkos::parallel_for(
"ArborX::TreeConstruction::initialize_single_leaf",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, 1), KOKKOS_LAMBDA(int) {
BoundingVolume bounding_volume{};
expand(bounding_volume, Access::get(primitives, 0));
leaf_nodes(0) = makeLeafNode(0, std::move(bounding_volume));
leaf_nodes(0) = makeLeafNode(Value{(unsigned)0, bounding_volume});
});
}

Expand Down Expand Up @@ -209,8 +208,9 @@ class GenerateHierarchy
expand(bounding_volume, Access::get(_primitives, original_index));

// Initialize leaf node
using Value = typename LeafNodes::value_type::value_type;
auto &leaf_node = _leaf_nodes(i);
leaf_node = makeLeafNode(original_index, bounding_volume);
leaf_node = makeLeafNode(Value{original_index, bounding_volume});

// For a leaf node, the range is just one index
int range_left = i;
Expand Down Expand Up @@ -271,7 +271,7 @@ class GenerateHierarchy
Kokkos::load_fence();
expand(bounding_volume,
right_child_is_leaf
? _leaf_nodes(right_child).bounding_volume
? _leaf_nodes(right_child).value.bounding_volume
: _internal_nodes(right_child).bounding_volume);
}
else
Expand All @@ -294,7 +294,7 @@ class GenerateHierarchy
Kokkos::load_fence();
expand(bounding_volume,
left_child_is_leaf
? _leaf_nodes(left_child).bounding_volume
? _leaf_nodes(left_child).value.bounding_volume
: _internal_nodes(left_child).bounding_volume);

if (!left_child_is_leaf)
Expand Down
17 changes: 7 additions & 10 deletions src/details/ArborX_DetailsTreeTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
HappyTreeFriends::getLeafBoundingVolume(_bvh, root);
if (predicate(root_bounding_volume))
{
_callback(predicate, 0);
_callback(predicate, HappyTreeFriends::getValue(_bvh, 0));
}
}

Expand All @@ -102,8 +102,7 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
if (is_leaf)
{
if (invoke_callback_and_check_early_exit(
_callback, predicate,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node)))
_callback, predicate, HappyTreeFriends::getValue(_bvh, node)))
return;
node = HappyTreeFriends::getRope(_bvh, node);
}
Expand Down Expand Up @@ -218,7 +217,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
if (k < 1)
return;

_callback(predicate, 0);
_callback(predicate, HappyTreeFriends::getValue(_bvh, 0));
}

KOKKOS_FUNCTION void operator()(int queryIndex) const
Expand Down Expand Up @@ -378,9 +377,8 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
sortHeap(heap.data(), heap.data() + heap.size(), heap.valueComp());
for (decltype(heap.size()) i = 0; i < heap.size(); ++i)
{
int const leaf_index = HappyTreeFriends::getLeafPermutationIndex(
_bvh, (heap.data() + i)->first);
_callback(predicate, leaf_index);
_callback(predicate,
HappyTreeFriends::getValue(_bvh, (heap.data() + i)->first));
}
}
};
Expand Down Expand Up @@ -440,7 +438,7 @@ struct TreeTraversal<BVH, Predicates, Callback,
KokkosExt::ArithmeticTraits::infinity<distance_type>::value;
if (distance(getGeometry(predicate), root_bounding_volume) != inf)
{
_callback(predicate, 0);
_callback(predicate, HappyTreeFriends::getValue(_bvh, 0));
}
}

Expand Down Expand Up @@ -487,8 +485,7 @@ struct TreeTraversal<BVH, Predicates, Callback,
if (HappyTreeFriends::isLeaf(_bvh, node))
{
if (invoke_callback_and_check_early_exit(
_callback, predicate,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node)))
_callback, predicate, HappyTreeFriends::getValue(_bvh, node)))
return;

if (heap.empty())
Expand Down
10 changes: 5 additions & 5 deletions src/details/ArborX_DetailsTreeVisualization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ struct TreeVisualization
{
auto const node_is_leaf = HappyTreeFriends::isLeaf(tree, node);
auto const node_index =
node_is_leaf ? HappyTreeFriends::getLeafPermutationIndex(tree, node)
: node;
node_is_leaf ? HappyTreeFriends::getValue(tree, node).index : node;
std::string label = node_is_leaf ? "l" : "i";
label.append(std::to_string(node_index));
return label;
Expand Down Expand Up @@ -154,9 +153,10 @@ struct TreeVisualization
struct VisitorCallback
{
template <typename Query>
KOKKOS_FUNCTION void operator()(Query const &, int index) const
KOKKOS_FUNCTION void
operator()(Query const &, typename TreeType::value_type const &value) const
{
_visitor.visit(_tree, permute(index));
_visitor.visit(_tree, permute(value.index));
}

TreeType _tree;
Expand Down Expand Up @@ -191,7 +191,7 @@ struct TreeVisualization
Kokkos::parallel_for(
"ArborX::Viz::compute_permutation",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n), KOKKOS_LAMBDA(int i) {
permute(HappyTreeFriends::getLeafPermutationIndex(tree, i)) = i;
permute(HappyTreeFriends::getValue(tree, i).index) = i;
});

Predicates predicates(Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
Expand Down
Loading

0 comments on commit f2e69c9

Please sign in to comment.