Skip to content

Commit

Permalink
Moving source view computation in private callback
Browse files (browse the repository at this point in the history)
Using primitives instead of access values in callback
  • Loading branch information
mrlag31 committed Dec 27, 2023
1 parent 8d9abeb commit 9933846
Showing 1 changed file with 41 additions and 56 deletions.
97 changes: 41 additions & 56 deletions src/interpolation/ArborX_InterpMovingLeastSquares.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,26 @@ struct MLSTargetPointsPredicateWrapper
int num_neighbors;
};

// NOTE(review): this span is a rendered diff hunk whose +/- markers were
// lost, so two versions of the functor are interleaved line by line. The
// GetSourceViewKernel lines appear to be the removed side and the
// SearchNeighborsCallback lines the added side -- confirm against the commit.
//
// GetSourceViewKernel: called with a pair (i, j); copies the source point
// whose index is stored at flat position i * num_neighbors + j of `indices`
// into source_view(i, j).
// Private kernel to compute the 2D source view
template <typename SourceView, typename SourcePoints, typename Indices>
struct GetSourceViewKernel
// SearchNeighborsCallback: called with a (predicate, primitive) pair; records
// both the primitive's index and its value in 2D views, one row per target.
// Functor used in the tree query to create the 2D source view and indices
template <typename SourceView, typename IndicesView, typename CounterView>
struct SearchNeighborsCallback
{
// 2D output written by both versions: source_view(row, column) holds a
// source point.
SourceView source_view;
// GetSourceViewKernel members: accessor over the source points, the flat
// index list, and the row stride used to address it.
ArborX::Details::AccessValues<SourcePoints, PrimitivesTag> source_access;
Indices indices;
int num_neighbors;
// SearchNeighborsCallback members: 2D index output and one counter per
// target, used to pick the next free column of that target's row.
IndicesView indices;
CounterView counter;

// Element type stored in source_view; also the value type of the primitive
// received by the callback.
using SourcePoint = typename SourceView::non_const_value_type;

// GetSourceViewKernel call operator (i, j) -- its body is the single gather
// statement directly after the opening brace below.
KOKKOS_FUNCTION void operator()(int const i, int const j) const
// SearchNeighborsCallback call operator -- its body is the five statements
// that follow the gather statement.
template <typename Predicate>
KOKKOS_FUNCTION void
operator()(Predicate const &predicate,
ArborX::PairValueIndex<SourcePoint> const &primitive) const
{
// GetSourceViewKernel body: gather through the flat index list.
source_view(i, j) = source_access(indices(i * num_neighbors + j));
// SearchNeighborsCallback body. The row is whatever data is attached to the
// predicate; presumably the target index -- verify against the code that
// builds the predicates.
int const target = getData(predicate);
int const source = primitive.index;
// Atomically bump this target's counter; the value fetched before the
// increment is used as the column to write.
// NOTE(review): nothing here checks `count` against the second extent of
// the views; presumably the query yields no more matches per target than
// there are columns -- confirm.
auto count = Kokkos::atomic_fetch_add(&counter(target), 1);
indices(target, count) = source;
source_view(target, count) = primitive.value;
}
};

Expand All @@ -73,7 +81,7 @@ struct AccessTraits<
get(Interpolation::Details::MLSTargetPointsPredicateWrapper<Points> const &tp,
int const i)
{
return nearest(tp.target_access(i), tp.num_neighbors);
return attach(nearest(tp.target_access(i), tp.num_neighbors), i);
}

using memory_space =
Expand Down Expand Up @@ -147,12 +155,8 @@ class MovingLeastSquares
// There must be enough source points
KOKKOS_ASSERT(0 < _num_neighbors && _num_neighbors <= _source_size);

// Search for neighbors
searchNeighbors(space, source_points, target_points);

// Fill in the value indices object so values can be transferred from a 1D
// source data to a properly distributed 2D array for each target.
auto const source_view = getSourceView(space, source_points);
// Search for neighbors and get the arranged source points
auto source_view = searchNeighbors(space, source_points, target_points);

// Compute the moving least squares coefficients
_coeffs = Details::movingLeastSquaresCoefficients<CRBFunc, PolynomialDegree,
Expand Down Expand Up @@ -208,45 +212,15 @@ class MovingLeastSquares
KOKKOS_CLASS_LAMBDA(int const i) {
Value tmp = 0;
for (int j = 0; j < _num_neighbors; j++)
tmp +=
_coeffs(i, j) * source_values(_indices(i * _num_neighbors + j));
tmp += _coeffs(i, j) * source_values(_indices(i, j));
approx_values(i) = tmp;
});
}

private:
// NOTE(review): in this rendered diff the whole function appears to be on the
// removed side (no replacement lines are interleaved) -- confirm against the
// commit.
//
// Builds and returns a _num_targets x _num_neighbors view of source points.
// The view is allocated without initialization and then filled by a
// Details::GetSourceViewKernel launched over the full (target, neighbor)
// index range on `space`.
template <typename ExecutionSpace, typename SourcePoints>
auto getSourceView(ExecutionSpace const &space,
SourcePoints const &source_points)
{
// Profiling region covering the rest of this function.
auto guard = Kokkos::Profiling::ScopedRegion(
"ArborX::MovingLeastSquares::getSourceView");

// The element type of the result is the value type exposed by the
// primitives accessor over source_points.
using SourceAccess =
ArborX::Details::AccessValues<SourcePoints, PrimitivesTag>;
using SourcePoint = typename SourceAccess::value_type;

// WithoutInitializing: the contents are unspecified until the kernel below
// has run; the launch range matches the view extents.
Kokkos::View<SourcePoint **, MemorySpace> source_view(
Kokkos::view_alloc(Kokkos::WithoutInitializing,
"ArborX::MovingLeastSquares::source_view"),
_num_targets, _num_neighbors);

// Aggregate-initialized in member order: output view, accessor built from
// source_points, the _indices member, and the neighbor count.
Details::GetSourceViewKernel<decltype(source_view), SourcePoints,
decltype(_indices)>
kernel{source_view, {source_points}, _indices, _num_neighbors};

// Rank-2 launch: one kernel call per (target, neighbor) pair.
Kokkos::parallel_for(
"ArborX::MovingLeastSquares::values_indices_and_source_view_fill",
Kokkos::MDRangePolicy<ExecutionSpace, Kokkos::Rank<2>>(
space, {0, 0}, {_num_targets, _num_neighbors}),
kernel);

return source_view;
}

// NOTE(review): rendered diff hunk with lost +/- markers. The `void`
// signature line and the "Query the source" statements (1D _indices, offsets,
// LegacyDefaultCallback) appear to be the removed side; the `auto` signature
// line and everything from "Create the callback" onward appear to be the
// added side -- confirm against the commit. Lines between the opening brace
// and the predicates declaration are collapsed in this view, so the
// declarations of `source_tree` and `SourcePoint` are not visible here.
//
// Added-side behavior visible below: allocates the 2D source-point view and
// the 2D _indices view, queries source_tree with a SearchNeighborsCallback
// that fills both, and returns the source-point view.
template <typename ExecutionSpace, typename SourcePoints,
typename TargetPoints>
void searchNeighbors(ExecutionSpace const &space,
auto searchNeighbors(ExecutionSpace const &space,
SourcePoints const &source_points,
TargetPoints const &target_points)
{
Expand All @@ -264,18 +238,29 @@ class MovingLeastSquares
// Wraps the target points together with the neighbor count so they can be
// passed to the tree query as predicates.
Details::MLSTargetPointsPredicateWrapper<TargetPoints> predicates{
{target_points}, _num_neighbors};

// Removed side: the query fills a flat 1D _indices view plus an offsets
// view, both created with size 0 and handed to the query.
// Query the source
_indices = Kokkos::View<int *, MemorySpace>(
"ArborX::MovingLeastSquares::indices", 0);
Kokkos::View<int *, MemorySpace> offsets(
"ArborX::MovingLeastSquares::offsets", 0);
source_tree.query(space, predicates,
ArborX::Details::LegacyDefaultCallback{}, _indices,
offsets);
// Added side: both outputs are _num_targets x _num_neighbors and allocated
// without initialization; their contents are whatever the callback writes
// during the query.
// Create the callback
Kokkos::View<SourcePoint **, MemorySpace> source_view(
Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
"ArborX::MovingLeastSquares::source_view"),
_num_targets, _num_neighbors);
_indices = Kokkos::View<int **, MemorySpace>(
Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
"ArborX::MovingLeastSquares::indices"),
_num_targets, _num_neighbors);
// One counter per target. Unlike the two views above it is constructed from
// a label only (no WithoutInitializing, no `space`); presumably it relies on
// Kokkos' default initialization so every count starts at zero -- confirm.
Kokkos::View<int *, MemorySpace> counter(
"ArborX::MovingLeastSquares::counter", _num_targets);
Details::SearchNeighborsCallback<decltype(source_view), decltype(_indices),
decltype(counter)>
callback{source_view, _indices, counter};

// Query the source tree
source_tree.query(space, predicates, callback);

return source_view;
}

Kokkos::View<FloatingCalculationType **, MemorySpace> _coeffs;
Kokkos::View<int *, MemorySpace> _indices;
Kokkos::View<int **, MemorySpace> _indices;
int _num_targets;
int _num_neighbors;
int _source_size;
Expand Down

0 comments on commit 9933846

Please sign in to comment.