Skip to content

Commit

Permalink
Create a helper file for distributed nearest stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed May 3, 2024
1 parent 8c17cd7 commit b8c0a24
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 193 deletions.
196 changes: 3 additions & 193 deletions src/details/ArborX_DetailsDistributedTreeNearest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,101 +11,22 @@
#ifndef ARBORX_DETAILS_DISTRIBUTED_TREE_NEAREST_HPP
#define ARBORX_DETAILS_DISTRIBUTED_TREE_NEAREST_HPP

#include <ArborX_AccessTraits.hpp>
#include <ArborX_Box.hpp>
#include <ArborX_DetailsDistributedTreeImpl.hpp>
#include <ArborX_DetailsDistributedTreeNearestHelpers.hpp>
#include <ArborX_DetailsDistributedTreeUtils.hpp>
#include <ArborX_DetailsHappyTreeFriends.hpp>
#include <ArborX_DetailsKokkosExtKernelStdAlgorithms.hpp>
#include <ArborX_DetailsKokkosExtMinMaxOperations.hpp>
#include <ArborX_DetailsKokkosExtStdAlgorithms.hpp>
#include <ArborX_DetailsKokkosExtViewHelpers.hpp>
#include <ArborX_LinearBVH.hpp>
#include <ArborX_Point.hpp>
#include <ArborX_Predicates.hpp>
#include <ArborX_Ray.hpp>
#include <ArborX_Sphere.hpp>

#include <Kokkos_Core.hpp>
#include <Kokkos_Profiling_ScopedRegion.hpp>

// Don't really need it, but our self containment tests rely on its presence
#include <mpi.h>

namespace ArborX
{
namespace Details
{
template <class Predicates, class Distances>
struct WithinDistanceFromPredicates
{
Predicates predicates;
Distances distances;
};
} // namespace Details

template <class Predicates, class Distances>
struct AccessTraits<
Details::WithinDistanceFromPredicates<Predicates, Distances>, PredicatesTag>
{
using Predicate = typename Predicates::value_type;
using Geometry =
std::decay_t<decltype(getGeometry(std::declval<Predicate const &>()))>;
using Self = Details::WithinDistanceFromPredicates<Predicates, Distances>;

using memory_space = typename Predicates::memory_space;
using size_type = decltype(std::declval<Predicates const &>().size());

static KOKKOS_FUNCTION size_type size(Self const &x)
{
return x.predicates.size();
}
template <class Dummy = Geometry,
std::enable_if_t<std::is_same_v<Dummy, Geometry> &&
std::is_same_v<Dummy, Point>> * = nullptr>
static KOKKOS_FUNCTION auto get(Self const &x, size_type i)
{
auto const point = getGeometry(x.predicates(i));
auto const distance = x.distances(i);
return intersects(Sphere{point, distance});
}
template <class Dummy = Geometry,
std::enable_if_t<std::is_same_v<Dummy, Geometry> &&
std::is_same_v<Dummy, Box>> * = nullptr>
static KOKKOS_FUNCTION auto get(Self const &x, size_type i)
{
auto box = getGeometry(x.predicates(i));
auto &min_corner = box.minCorner();
auto &max_corner = box.maxCorner();
auto const distance = x.distances(i);
for (int d = 0; d < 3; ++d)
{
min_corner[d] -= distance;
max_corner[d] += distance;
}
return intersects(box);
}
template <class Dummy = Geometry,
std::enable_if_t<std::is_same_v<Dummy, Geometry> &&
std::is_same_v<Dummy, Sphere>> * = nullptr>
static KOKKOS_FUNCTION auto get(Self const &x, size_type i)
{
auto const sphere = getGeometry(x.predicates(i));
auto const distance = x.distances(i);
return intersects(Sphere{sphere.centroid(), distance + sphere.radius()});
}
template <
class Dummy = Geometry,
std::enable_if_t<std::is_same_v<Dummy, Geometry> &&
std::is_same_v<Dummy, Experimental::Ray>> * = nullptr>
static KOKKOS_FUNCTION auto get(Self const &x, size_type i)
{
auto const ray = getGeometry(x.predicates(i));
return intersects(ray);
}
};

namespace Details
namespace ArborX::Details
{

template <typename Value>
Expand All @@ -115,116 +36,6 @@ struct PairValueDistance
float distance;
};

template <typename Tree, typename Callback, typename OutValue, bool UseValues>
struct CallbackWithDistance
{
Tree _tree;
Callback _callback;

template <typename ExecutionSpace>
CallbackWithDistance(ExecutionSpace const &, Tree const &tree,
Callback const &callback)
: _tree(tree)
, _callback(callback)
{}

template <typename Query, typename Value, typename Output>
KOKKOS_FUNCTION void operator()(Query const &query, Value const &value,
Output const &out) const
{
if constexpr (UseValues)
{
OutValue out_value;
int count = 0;
_callback(query, value, [&](OutValue const &ov) {
out_value = ov;
++count;
});
KOKKOS_ASSERT(count == 1);
out({out_value,
distance(getGeometry(query), _tree.indexable_get()(value))});
}
else
out(distance(getGeometry(query), _tree.indexable_get()(value)));
}
};

template <typename MemorySpace, typename Callback, typename OutValue,
bool UseValues>
struct CallbackWithDistance<
BoundingVolumeHierarchy<MemorySpace, Details::LegacyDefaultTemplateValue,
Details::DefaultIndexableGetter,
ExperimentalHyperGeometry::Box<3, float>>,
Callback, OutValue, UseValues>
{
using Tree =
BoundingVolumeHierarchy<MemorySpace, Details::LegacyDefaultTemplateValue,
Details::DefaultIndexableGetter,
ExperimentalHyperGeometry::Box<3, float>>;

Tree _tree;
Callback _callback;
Kokkos::View<unsigned int *, typename Tree::memory_space> _rev_permute;

template <typename ExecutionSpace>
CallbackWithDistance(ExecutionSpace const &exec_space, Tree const &tree,
Callback const &callback)
: _tree(tree)
, _callback(callback)
{
// NOTE cannot have extended __host__ __device__ lambda in constructor with
// NVCC
computeReversePermutation(exec_space);
}

template <typename ExecutionSpace>
void computeReversePermutation(ExecutionSpace const &exec_space)
{
auto const n = _tree.size();

_rev_permute = Kokkos::View<unsigned int *, typename Tree::memory_space>(
Kokkos::view_alloc(
Kokkos::WithoutInitializing,
"ArborX::DistributedTree::query::nearest::reverse_permutation"),
n);
if (!_tree.empty())
{
Kokkos::parallel_for(
"ArborX::DistributedTree::query::nearest::"
"compute_reverse_permutation",
Kokkos::RangePolicy<ExecutionSpace>(exec_space, 0, n),
KOKKOS_CLASS_LAMBDA(int const i) {
_rev_permute(HappyTreeFriends::getValue(_tree, i).index) = i;
});
}
}

template <typename Query, typename OutputFunctor>
KOKKOS_FUNCTION void operator()(Query const &query, int index,
OutputFunctor const &out) const
{
// TODO: This breaks the abstraction of the distributed Tree not knowing
// the details of the local tree. Right now, this is the only way. Will
// need to be fixed with a proper callback abstraction.
int const leaf_node_index = _rev_permute(index);
auto const &leaf_node_bounding_volume =
HappyTreeFriends::getIndexable(_tree, leaf_node_index);
if constexpr (UseValues)
{
OutValue out_value;
int count = 0;
_callback(query, index, [&](OutValue const &ov) {
out_value = ov;
++count;
});
KOKKOS_ASSERT(count == 1);
out({out_value, distance(getGeometry(query), leaf_node_bounding_volume)});
}
else
out(distance(getGeometry(query), leaf_node_bounding_volume));
}
};

template <typename ExecutionSpace, typename Tree, typename Predicates,
typename Distances>
void DistributedTreeImpl::phaseI(ExecutionSpace const &space, Tree const &tree,
Expand Down Expand Up @@ -428,7 +239,6 @@ DistributedTreeImpl::queryDispatch(NearestPredicateTag tag, Tree const &tree,
queryDispatchImpl(tag, tree, space, predicates, callback, values, offset);
}

} // namespace Details
} // namespace ArborX
} // namespace ArborX::Details

#endif
Loading

0 comments on commit b8c0a24

Please sign in to comment.