Skip to content

Commit

Permalink
Passing around AccessValues instead of raw user inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
mrlag31 committed Dec 27, 2023
1 parent 9933846 commit a1f89a5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 32 deletions.
44 changes: 21 additions & 23 deletions src/interpolation/ArborX_InterpMovingLeastSquares.hpp
Expand Up @@ -31,10 +31,10 @@ namespace ArborX::Interpolation::Details
{

// This is done to avoid a clash with another predicates access trait
template <typename TargetPoints>
template <typename TargetAccess>
struct MLSTargetPointsPredicateWrapper
{
ArborX::Details::AccessValues<TargetPoints, PrimitivesTag> target_access;
TargetAccess target_access;
int num_neighbors;
};

Expand Down Expand Up @@ -66,26 +66,25 @@ struct SearchNeighborsCallback
namespace ArborX
{

template <typename Points>
template <typename TargetAccess>
struct AccessTraits<
Interpolation::Details::MLSTargetPointsPredicateWrapper<Points>,
Interpolation::Details::MLSTargetPointsPredicateWrapper<TargetAccess>,
PredicatesTag>
{
KOKKOS_FUNCTION static auto size(
Interpolation::Details::MLSTargetPointsPredicateWrapper<Points> const &tp)
using Self =
Interpolation::Details::MLSTargetPointsPredicateWrapper<TargetAccess>;

KOKKOS_FUNCTION static auto size(Self const &tp)
{
return tp.target_access.size();
}

KOKKOS_FUNCTION static auto
get(Interpolation::Details::MLSTargetPointsPredicateWrapper<Points> const &tp,
int const i)
KOKKOS_FUNCTION static auto get(Self const &tp, int const i)
{
return attach(nearest(tp.target_access(i), tp.num_neighbors), i);
}

using memory_space =
typename Details::AccessValues<Points, PrimitivesTag>::memory_space;
using memory_space = typename TargetAccess::memory_space;
};

} // namespace ArborX
Expand Down Expand Up @@ -150,18 +149,19 @@ class MovingLeastSquares

TargetAccess target_access{target_points};
SourceAccess source_access{source_points};

_num_targets = target_access.size();
_source_size = source_access.size();
// There must be enough source points
KOKKOS_ASSERT(0 < _num_neighbors && _num_neighbors <= _source_size);

// Search for neighbors and get the arranged source points
auto source_view = searchNeighbors(space, source_points, target_points);
auto source_view = searchNeighbors(space, source_access, target_access);

// Compute the moving least squares coefficients
_coeffs = Details::movingLeastSquaresCoefficients<CRBFunc, PolynomialDegree,
FloatingCalculationType>(
space, source_view, target_points);
space, source_view, target_access._values);
}

template <typename ExecutionSpace, typename SourceValues,
Expand Down Expand Up @@ -218,25 +218,23 @@ class MovingLeastSquares
}

private:
template <typename ExecutionSpace, typename SourcePoints,
typename TargetPoints>
template <typename ExecutionSpace, typename SourceAccess,
typename TargetAccess>
auto searchNeighbors(ExecutionSpace const &space,
SourcePoints const &source_points,
TargetPoints const &target_points)
SourceAccess const &source_access,
TargetAccess const &target_access)
{
auto guard = Kokkos::Profiling::ScopedRegion(
"ArborX::MovingLeastSquares::searchNeighbors");

// Organize the source points as a tree
using SourcePoint =
typename ArborX::Details::AccessValues<SourcePoints,
PrimitivesTag>::value_type;
using SourcePoint = typename SourceAccess::value_type;
BoundingVolumeHierarchy<MemorySpace, ArborX::PairValueIndex<SourcePoint>>
source_tree(space, ArborX::Experimental::attach_indices(source_points));
source_tree(space, ArborX::Experimental::attach_indices(source_access));

// Create the predicates
Details::MLSTargetPointsPredicateWrapper<TargetPoints> predicates{
{target_points}, _num_neighbors};
Details::MLSTargetPointsPredicateWrapper<TargetAccess> predicates{
target_access, _num_neighbors};

// Create the callback
Kokkos::View<SourcePoint **, MemorySpace> source_view(
Expand Down
Expand Up @@ -28,11 +28,11 @@ namespace ArborX::Interpolation::Details

template <typename CRBFunc, typename PolynomialDegree,
typename CoefficientsType, typename ExecutionSpace,
typename SourcePoints, typename TargetPoints>
typename SourcePoints, typename TargetAccess>
Kokkos::View<CoefficientsType **, typename SourcePoints::memory_space>
movingLeastSquaresCoefficients(ExecutionSpace const &space,
SourcePoints const &source_points,
TargetPoints const &target_points)
TargetAccess const &target_access)
{
auto guard =
Kokkos::Profiling::ScopedRegion("ArborX::MovingLeastSquaresCoefficients");
Expand All @@ -57,22 +57,18 @@ movingLeastSquaresCoefficients(ExecutionSpace const &space,
"source points elements must be points");
static constexpr int dimension = GeometryTraits::dimension_v<SourcePoint>;

// TargetPoints is an access trait of points
ArborX::Details::check_valid_access_traits(PrimitivesTag{}, target_points);
using TargetAccess =
ArborX::Details::AccessValues<TargetPoints, PrimitivesTag>;
// TargetAccess is an access values of points
static_assert(
KokkosExt::is_accessible_from<typename TargetAccess::memory_space,
ExecutionSpace>::value,
"target points must be accessible from the execution space");
"target access must be accessible from the execution space");
using TargetPoint = typename TargetAccess::value_type;
GeometryTraits::check_valid_geometry_traits(TargetPoint{});
static_assert(GeometryTraits::is_point<TargetPoint>::value,
"target points elements must be points");
"target access elements must be points");
static_assert(dimension == GeometryTraits::dimension_v<TargetPoint>,
"target and source points must have the same dimension");

TargetAccess target_access{target_points};
int const num_targets = target_access.size();
int const num_neighbors = source_points.extent(1);

Expand Down

0 comments on commit a1f89a5

Please sign in to comment.