Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Value as a generic storage in a leaf node #871

Merged
merged 2 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class BasicBoundingVolumeHierarchy
static_assert(Kokkos::is_memory_space<MemorySpace>::value);
using size_type = typename MemorySpace::size_type;
using bounding_volume_type = BoundingVolume;
using value_type = Details::PairIndexVolume<bounding_volume_type>;

BasicBoundingVolumeHierarchy() = default; // build an empty tree

Expand Down Expand Up @@ -85,7 +86,7 @@ class BasicBoundingVolumeHierarchy
private:
friend struct Details::HappyTreeFriends;

using leaf_node_type = Details::LeafNode<bounding_volume_type>;
using leaf_node_type = Details::LeafNode<value_type>;
using internal_node_type = Details::InternalNode<bounding_volume_type>;

KOKKOS_FUNCTION
Expand All @@ -98,7 +99,7 @@ class BasicBoundingVolumeHierarchy
assert((n == 1 || Details::HappyTreeFriends::getRoot(*this) == n) &&
"workaround below assumes root is stored as first element");
return (n > 1 ? &_internal_nodes.data()->bounding_volume
: &_leaf_nodes.data()->bounding_volume);
: &_leaf_nodes.data()->value.bounding_volume);
}

size_type _size{0};
Expand Down Expand Up @@ -212,7 +213,8 @@ template <typename MemorySpace, typename BoundingVolume, typename Enable>
template <typename ExecutionSpace, typename Predicates, typename Callback>
void BasicBoundingVolumeHierarchy<MemorySpace, BoundingVolume, Enable>::query(
ExecutionSpace const &space, Predicates const &predicates,
Callback const &callback, Experimental::TraversalPolicy const &policy) const
Callback const &legacy_callback,
Experimental::TraversalPolicy const &policy) const
{
static_assert(
KokkosExt::is_accessible_from<MemorySpace, ExecutionSpace>::value);
Expand All @@ -221,7 +223,9 @@ void BasicBoundingVolumeHierarchy<MemorySpace, BoundingVolume, Enable>::query(
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");
Details::check_valid_callback(callback, predicates);
Details::check_valid_callback(legacy_callback, predicates);
Details::LegacyCallbackWrapper<Callback, value_type> callback{
legacy_callback};

using Tag = typename Details::AccessTraitsHelper<Access>::tag;
std::string profiling_prefix = "ArborX::BVH::query::";
Expand Down
13 changes: 13 additions & 0 deletions src/details/ArborX_Callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,19 @@ void check_valid_callback(Callback const &callback, Predicates const &)
"Callback 'operator()' return type must be void");
}

template <typename Callback, typename Value>
struct LegacyCallbackWrapper
{
Callback _callback;

template <typename Predicate>
KOKKOS_FUNCTION auto operator()(Predicate const &predicate,
Value const &value) const
{
return _callback(predicate, value.index);
}
};

} // namespace Details
} // namespace ArborX

Expand Down
3 changes: 1 addition & 2 deletions src/details/ArborX_DetailsDistributedTreeImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,7 @@ struct CallbackWithDistance
"compute_reverse_permutation",
Kokkos::RangePolicy<ExecutionSpace>(exec_space, 0, n),
ARBORX_CLASS_LAMBDA(int const i) {
_rev_permute(HappyTreeFriends::getLeafPermutationIndex(_tree, i)) =
i;
_rev_permute(HappyTreeFriends::getValue(_tree, i).index) = i;
});
}
}
Expand Down
9 changes: 4 additions & 5 deletions src/details/ArborX_DetailsHalfTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef ARBORX_DETAILS_HALF_TRAVERSAL_HPP
#define ARBORX_DETAILS_HALF_TRAVERSAL_HPP

#include <ArborX_Callbacks.hpp> // LegacyCallbackWrapper
#include <ArborX_DetailsHappyTreeFriends.hpp>
#include <ArborX_DetailsNode.hpp> // ROPE_SENTINEL

Expand All @@ -25,7 +26,7 @@ struct HalfTraversal
{
BVH _bvh;
PredicateGetter _get_predicate;
Callback _callback;
LegacyCallbackWrapper<Callback, typename BVH::value_type> _callback;

template <class ExecutionSpace>
HalfTraversal(ExecutionSpace const &space, BVH const &bvh,
Expand Down Expand Up @@ -54,8 +55,7 @@ struct HalfTraversal
{
auto const predicate =
_get_predicate(HappyTreeFriends::getLeafBoundingVolume(_bvh, i));
auto const leaf_permutation_i =
HappyTreeFriends::getLeafPermutationIndex(_bvh, i);
auto const leaf_permutation_i = HappyTreeFriends::getValue(_bvh, i).index;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yuk. I don't have a better idea though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, would be nice to just have getValue() here. If we ever figure out how to move away from storing (index, volume) together in the future, this would be possible. But right now, it comes with severe penalties.


int node = HappyTreeFriends::getRope(_bvh, i);
while (node != ROPE_SENTINEL)
Expand All @@ -69,8 +69,7 @@ struct HalfTraversal
{
if (is_leaf)
{
_callback(leaf_permutation_i,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node));
_callback(leaf_permutation_i, HappyTreeFriends::getValue(_bvh, node));
node = HappyTreeFriends::getRope(_bvh, node);
}
else
Expand Down
8 changes: 4 additions & 4 deletions src/details/ArborX_DetailsHappyTreeFriends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ struct HappyTreeFriends
{
static_assert(
std::is_same_v<decltype(bvh._internal_nodes(0).bounding_volume),
decltype(bvh._leaf_nodes(0).bounding_volume)>);
return bvh._leaf_nodes(i).bounding_volume;
decltype(bvh._leaf_nodes(0).value.bounding_volume)>);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still want the static assert here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now. Do you have a better idea or place for it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have dropped it since this is not something that user input can break but I don't feel too strongly about it.

return bvh._leaf_nodes(i).value.bounding_volume;
}

template <class BVH>
static KOKKOS_FUNCTION auto getLeafPermutationIndex(BVH const &bvh, int i)
static KOKKOS_FUNCTION auto const &getValue(BVH const &bvh, int i)
{
assert(i >= 0 && i < (int)bvh.size());
return bvh._leaf_nodes(i).permutation_index;
return bvh._leaf_nodes(i).value;
}

template <class BVH>
Expand Down
30 changes: 16 additions & 14 deletions src/details/ArborX_DetailsNode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,27 @@
#include <Kokkos_Macros.hpp>

#include <cassert>
#include <climits> // UINT_MAX
#include <utility> // std::move

namespace ArborX
{
namespace Details
namespace ArborX::Details
{

constexpr int ROPE_SENTINEL = -1;

template <class BoundingVolume>
struct PairIndexVolume
{
unsigned index;
BoundingVolume bounding_volume;
};

template <class Value>
struct LeafNode
{
using bounding_volume_type = BoundingVolume;
using value_type = Value;

unsigned permutation_index = UINT_MAX;
int rope = ROPE_SENTINEL;
BoundingVolume bounding_volume;
Value value;
};

template <class BoundingVolume>
Expand All @@ -48,14 +51,13 @@ struct InternalNode
BoundingVolume bounding_volume;
};

template <class BoundingVolume>
KOKKOS_INLINE_FUNCTION constexpr LeafNode<BoundingVolume>
makeLeafNode(unsigned permutation_index,
BoundingVolume bounding_volume) noexcept
template <class Value>
KOKKOS_INLINE_FUNCTION constexpr LeafNode<Value>
makeLeafNode(Value value) noexcept
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like so much spelling the template parameter at the call site but I suppose this way you won't have to change the definition later.

{
return {permutation_index, ROPE_SENTINEL, std::move(bounding_volume)};
return {ROPE_SENTINEL, std::move(value)};
}
} // namespace Details
} // namespace ArborX

} // namespace ArborX::Details

#endif
14 changes: 7 additions & 7 deletions src/details/ArborX_DetailsTreeConstruction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,18 @@ inline void initializeSingleLeafNode(ExecutionSpace const &space,
Nodes const &leaf_nodes)
{
using Access = AccessTraits<Primitives, PrimitivesTag>;
using Value = typename Nodes::value_type::value_type;
using BoundingVolume = decltype(std::declval<Value>().bounding_volume);

ARBORX_ASSERT(leaf_nodes.extent(0) == 1);
ARBORX_ASSERT(Access::size(primitives) == 1);

using Node = typename Nodes::value_type;
using BoundingVolume = typename Node::bounding_volume_type;

Kokkos::parallel_for(
"ArborX::TreeConstruction::initialize_single_leaf",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, 1), KOKKOS_LAMBDA(int) {
BoundingVolume bounding_volume{};
expand(bounding_volume, Access::get(primitives, 0));
leaf_nodes(0) = makeLeafNode(0, std::move(bounding_volume));
leaf_nodes(0) = makeLeafNode(Value{(unsigned)0, bounding_volume});
});
}

Expand Down Expand Up @@ -209,8 +208,9 @@ class GenerateHierarchy
expand(bounding_volume, Access::get(_primitives, original_index));

// Initialize leaf node
using Value = typename LeafNodes::value_type::value_type;
auto &leaf_node = _leaf_nodes(i);
leaf_node = makeLeafNode(original_index, bounding_volume);
leaf_node = makeLeafNode(Value{original_index, bounding_volume});

// For a leaf node, the range is just one index
int range_left = i;
Expand Down Expand Up @@ -271,7 +271,7 @@ class GenerateHierarchy
Kokkos::load_fence();
expand(bounding_volume,
right_child_is_leaf
? _leaf_nodes(right_child).bounding_volume
? _leaf_nodes(right_child).value.bounding_volume
: _internal_nodes(right_child).bounding_volume);
}
else
Expand All @@ -294,7 +294,7 @@ class GenerateHierarchy
Kokkos::load_fence();
expand(bounding_volume,
left_child_is_leaf
? _leaf_nodes(left_child).bounding_volume
? _leaf_nodes(left_child).value.bounding_volume
: _internal_nodes(left_child).bounding_volume);

if (!left_child_is_leaf)
Expand Down
17 changes: 7 additions & 10 deletions src/details/ArborX_DetailsTreeTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
HappyTreeFriends::getLeafBoundingVolume(_bvh, root);
if (predicate(root_bounding_volume))
{
_callback(predicate, 0);
_callback(predicate, HappyTreeFriends::getValue(_bvh, 0));
}
}

Expand All @@ -102,8 +102,7 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
if (is_leaf)
{
if (invoke_callback_and_check_early_exit(
_callback, predicate,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node)))
_callback, predicate, HappyTreeFriends::getValue(_bvh, node)))
return;
node = HappyTreeFriends::getRope(_bvh, node);
}
Expand Down Expand Up @@ -218,7 +217,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
if (k < 1)
return;

_callback(predicate, 0);
_callback(predicate, HappyTreeFriends::getValue(_bvh, 0));
}

KOKKOS_FUNCTION void operator()(int queryIndex) const
Expand Down Expand Up @@ -378,9 +377,8 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
sortHeap(heap.data(), heap.data() + heap.size(), heap.valueComp());
for (decltype(heap.size()) i = 0; i < heap.size(); ++i)
{
int const leaf_index = HappyTreeFriends::getLeafPermutationIndex(
_bvh, (heap.data() + i)->first);
_callback(predicate, leaf_index);
_callback(predicate,
HappyTreeFriends::getValue(_bvh, (heap.data() + i)->first));
}
}
};
Expand Down Expand Up @@ -440,7 +438,7 @@ struct TreeTraversal<BVH, Predicates, Callback,
KokkosExt::ArithmeticTraits::infinity<distance_type>::value;
if (distance(getGeometry(predicate), root_bounding_volume) != inf)
{
_callback(predicate, 0);
_callback(predicate, HappyTreeFriends::getValue(_bvh, 0));
}
}

Expand Down Expand Up @@ -487,8 +485,7 @@ struct TreeTraversal<BVH, Predicates, Callback,
if (HappyTreeFriends::isLeaf(_bvh, node))
{
if (invoke_callback_and_check_early_exit(
_callback, predicate,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node)))
_callback, predicate, HappyTreeFriends::getValue(_bvh, node)))
return;

if (heap.empty())
Expand Down
10 changes: 5 additions & 5 deletions src/details/ArborX_DetailsTreeVisualization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ struct TreeVisualization
{
auto const node_is_leaf = HappyTreeFriends::isLeaf(tree, node);
auto const node_index =
node_is_leaf ? HappyTreeFriends::getLeafPermutationIndex(tree, node)
: node;
node_is_leaf ? HappyTreeFriends::getValue(tree, node).index : node;
std::string label = node_is_leaf ? "l" : "i";
label.append(std::to_string(node_index));
return label;
Expand Down Expand Up @@ -154,9 +153,10 @@ struct TreeVisualization
struct VisitorCallback
{
template <typename Query>
KOKKOS_FUNCTION void operator()(Query const &, int index) const
KOKKOS_FUNCTION void
operator()(Query const &, typename TreeType::value_type const &value) const
{
_visitor.visit(_tree, permute(index));
_visitor.visit(_tree, permute(value.index));
}

TreeType _tree;
Expand Down Expand Up @@ -191,7 +191,7 @@ struct TreeVisualization
Kokkos::parallel_for(
"ArborX::Viz::compute_permutation",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n), KOKKOS_LAMBDA(int i) {
permute(HappyTreeFriends::getLeafPermutationIndex(tree, i)) = i;
permute(HappyTreeFriends::getValue(tree, i).index) = i;
});

Predicates predicates(Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
Expand Down
Loading