Skip to content

Commit

Permalink
Private get source view
Browse files Browse the repository at this point in the history
  • Loading branch information
mrlag31 committed Dec 27, 2023
1 parent fc7a2ed commit 2f814b5
Showing 1 changed file with 44 additions and 29 deletions.
73 changes: 44 additions & 29 deletions src/interpolation/ArborX_InterpMovingLeastSquares.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,21 @@ struct MLSTargetPointsPredicateWrapper
int num_neighbors;
};

// Private kernel to compute the 2D source view
template <typename SourceView, typename SourcePoints, typename Indices>
struct GetSourceViewKernel
{
SourceView source_view;
ArborX::Details::AccessValues<SourcePoints, PrimitivesTag> source_access;
Indices indices;
int num_neighbors;

KOKKOS_FUNCTION void operator()(int const i, int const j) const
{
source_view(i, j) = source_access(indices(i * num_neighbors + j));
}
};

} // namespace ArborX::Interpolation::Details

namespace ArborX
Expand Down Expand Up @@ -159,35 +174,6 @@ class MovingLeastSquares
space, source_view, target_points);
}

template <typename ExecutionSpace, typename SourcePoints>
Kokkos::View<typename ArborX::Details::AccessTraitsHelper<
AccessTraits<SourcePoints, PrimitivesTag>>::type **,
MemorySpace>
getSourceView(ExecutionSpace const &space, SourcePoints const &source_points)
{
auto guard = Kokkos::Profiling::ScopedRegion(
"ArborX::MovingLeastSquares::getSourceView");

using SourceAccess =
ArborX::Details::AccessValues<SourcePoints, PrimitivesTag>;
using SourcePoint = typename SourceAccess::value_type;

Kokkos::View<SourcePoint **, MemorySpace> source_view(
Kokkos::view_alloc(Kokkos::WithoutInitializing,
"ArborX::MovingLeastSquares::source_view"),
_num_targets, _num_neighbors);
SourceAccess source_access{source_points};
Kokkos::parallel_for(
"ArborX::MovingLeastSquares::values_indices_and_source_view_fill",
Kokkos::MDRangePolicy<ExecutionSpace, Kokkos::Rank<2>>(
space, {0, 0}, {_num_targets, _num_neighbors}),
KOKKOS_CLASS_LAMBDA(int const i, int const j) {
source_view(i, j) = source_access(_indices(i * _num_neighbors + j));
});

return source_view;
}

template <typename ExecutionSpace, typename SourceValues,
typename ApproxValues>
void interpolate(ExecutionSpace const &space,
Expand Down Expand Up @@ -243,6 +229,35 @@ class MovingLeastSquares
}

private:
template <typename ExecutionSpace, typename SourcePoints>
auto getSourceView(ExecutionSpace const &space,
SourcePoints const &source_points)
{
auto guard = Kokkos::Profiling::ScopedRegion(
"ArborX::MovingLeastSquares::getSourceView");

using SourceAccess =
ArborX::Details::AccessValues<SourcePoints, PrimitivesTag>;
using SourcePoint = typename SourceAccess::value_type;

Kokkos::View<SourcePoint **, MemorySpace> source_view(
Kokkos::view_alloc(Kokkos::WithoutInitializing,
"ArborX::MovingLeastSquares::source_view"),
_num_targets, _num_neighbors);

Details::GetSourceViewKernel<decltype(source_view), SourcePoints,
decltype(_indices)>
kernel{source_view, {source_points}, _indices, _num_neighbors};

Kokkos::parallel_for(
"ArborX::MovingLeastSquares::values_indices_and_source_view_fill",
Kokkos::MDRangePolicy<ExecutionSpace, Kokkos::Rank<2>>(
space, {0, 0}, {_num_targets, _num_neighbors}),
kernel);

return source_view;
}

Kokkos::View<FloatingCalculationType **, MemorySpace> _coeffs;
Kokkos::View<int *, MemorySpace> _indices;
int _num_targets;
Expand Down

0 comments on commit 2f814b5

Please sign in to comment.