Skip to content

Commit

Permalink
Using AccessValues instead of AccessTraits
Browse files Browse the repository at this point in the history
  • Loading branch information
mrlag31 committed Dec 27, 2023
1 parent 6d907d9 commit 26ba073
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 44 deletions.
71 changes: 37 additions & 34 deletions src/interpolation/ArborX_InterpMovingLeastSquares.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ namespace ArborX::Interpolation::Details
{

// This is done to avoid a clash with another predicates access trait
template <typename Points>
template <typename TargetPoints>
struct MLSTargetPointsPredicateWrapper
{
Points target_points;
ArborX::Details::AccessValues<TargetPoints, PrimitivesTag> target_access;
int num_neighbors;
};

Expand All @@ -51,20 +51,18 @@ struct AccessTraits<
KOKKOS_FUNCTION static auto size(
Interpolation::Details::MLSTargetPointsPredicateWrapper<Points> const &tp)
{
return AccessTraits<Points, PrimitivesTag>::size(tp.target_points);
return tp.target_access.size();
}

KOKKOS_FUNCTION static auto
get(Interpolation::Details::MLSTargetPointsPredicateWrapper<Points> const &tp,
int const i)
{
return nearest(
AccessTraits<Points, PrimitivesTag>::get(tp.target_points, i),
tp.num_neighbors);
return nearest(tp.target_access(i), tp.num_neighbors);
}

using memory_space =
typename AccessTraits<Points, PrimitivesTag>::memory_space;
typename Details::AccessValues<Points, PrimitivesTag>::memory_space;
};

} // namespace ArborX
Expand Down Expand Up @@ -95,48 +93,52 @@ class MovingLeastSquares

// SourcePoints is an access trait of points
ArborX::Details::check_valid_access_traits(PrimitivesTag{}, source_points);
using src_acc = AccessTraits<SourcePoints, PrimitivesTag>;
static_assert(KokkosExt::is_accessible_from<typename src_acc::memory_space,
ExecutionSpace>::value,
"Source points must be accessible from the execution space");
using src_point =
typename ArborX::Details::AccessTraitsHelper<src_acc>::type;
GeometryTraits::check_valid_geometry_traits(src_point{});
static_assert(GeometryTraits::is_point<src_point>::value,
using SourceAccess =
ArborX::Details::AccessValues<SourcePoints, PrimitivesTag>;
static_assert(
KokkosExt::is_accessible_from<typename SourceAccess::memory_space,
ExecutionSpace>::value,
"Source points must be accessible from the execution space");
using SourcePoint = typename SourceAccess::value_type;
GeometryTraits::check_valid_geometry_traits(SourcePoint{});
static_assert(GeometryTraits::is_point<SourcePoint>::value,
"Source points elements must be points");
static constexpr int dimension = GeometryTraits::dimension_v<src_point>;
static constexpr int dimension = GeometryTraits::dimension_v<SourcePoint>;

// TargetPoints is an access trait of points
ArborX::Details::check_valid_access_traits(PrimitivesTag{}, target_points);
using tgt_acc = AccessTraits<TargetPoints, PrimitivesTag>;
static_assert(KokkosExt::is_accessible_from<typename tgt_acc::memory_space,
ExecutionSpace>::value,
"Target points must be accessible from the execution space");
using tgt_point =
typename ArborX::Details::AccessTraitsHelper<tgt_acc>::type;
GeometryTraits::check_valid_geometry_traits(tgt_point{});
static_assert(GeometryTraits::is_point<tgt_point>::value,
using TargetAccess =
ArborX::Details::AccessValues<TargetPoints, PrimitivesTag>;
static_assert(
KokkosExt::is_accessible_from<typename TargetAccess::memory_space,
ExecutionSpace>::value,
"Target points must be accessible from the execution space");
using TargetPoint = typename TargetAccess::value_type;
GeometryTraits::check_valid_geometry_traits(TargetPoint{});
static_assert(GeometryTraits::is_point<TargetPoint>::value,
"Target points elements must be points");
static_assert(dimension == GeometryTraits::dimension_v<tgt_point>,
static_assert(dimension == GeometryTraits::dimension_v<TargetPoint>,
"Target and source points must have the same dimension");

int num_neighbors_val =
num_neighbors ? *num_neighbors
: Details::polynomialBasisSize<dimension,
PolynomialDegree::value>();

int const num_targets = tgt_acc::size(target_points);
_source_size = src_acc::size(source_points);
TargetAccess target_access{target_points};
SourceAccess source_access{source_points};
int const num_targets = target_access.size();
_source_size = source_access.size();
// There must be enough source points
KOKKOS_ASSERT(0 < num_neighbors_val && num_neighbors_val <= _source_size);

// Organize the source points as a tree
BoundingVolumeHierarchy<MemorySpace, ArborX::PairValueIndex<src_point>>
BoundingVolumeHierarchy<MemorySpace, ArborX::PairValueIndex<SourcePoint>>
source_tree(space, ArborX::Experimental::attach_indices(source_points));

// Create the predicates
Details::MLSTargetPointsPredicateWrapper<TargetPoints> predicates{
target_points, num_neighbors_val};
{target_points}, num_neighbors_val};

// Query the source
Kokkos::View<int *, MemorySpace> indices(
Expand Down Expand Up @@ -171,26 +173,27 @@ class MovingLeastSquares
auto guard = Kokkos::Profiling::ScopedRegion(
"ArborX::MovingLeastSquares::fillValuesIndicesAndGetSourceView");

using src_acc = AccessTraits<SourcePoints, PrimitivesTag>;
using src_point =
typename ArborX::Details::AccessTraitsHelper<src_acc>::type;
using SourceAccess =
ArborX::Details::AccessValues<SourcePoints, PrimitivesTag>;
using SourcePoint = typename SourceAccess::value_type;

_values_indices = Kokkos::View<int **, MemorySpace>(
Kokkos::view_alloc(Kokkos::WithoutInitializing,
"ArborX::MovingLeastSquares::values_indices"),
num_targets, num_neighbors);
Kokkos::View<src_point **, MemorySpace> source_view(
Kokkos::View<SourcePoint **, MemorySpace> source_view(
Kokkos::view_alloc(Kokkos::WithoutInitializing,
"ArborX::MovingLeastSquares::source_view"),
num_targets, num_neighbors);
SourceAccess source_access{source_points};
Kokkos::parallel_for(
"ArborX::MovingLeastSquares::values_indices_and_source_view_fill",
Kokkos::MDRangePolicy<ExecutionSpace, Kokkos::Rank<2>>(
space, {0, 0}, {num_targets, num_neighbors}),
KOKKOS_CLASS_LAMBDA(int const i, int const j) {
auto index = indices(offsets(i) + j);
_values_indices(i, j) = index;
source_view(i, j) = src_acc::get(source_points, index);
source_view(i, j) = source_access(index);
});

return source_view;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,21 @@ movingLeastSquaresCoefficients(ExecutionSpace const &space,

// TargetPoints is an access trait of points
ArborX::Details::check_valid_access_traits(PrimitivesTag{}, target_points);
using tgt_acc = AccessTraits<TargetPoints, PrimitivesTag>;
static_assert(KokkosExt::is_accessible_from<typename tgt_acc::memory_space,
ExecutionSpace>::value,
"target points must be accessible from the execution space");
using tgt_point = typename ArborX::Details::AccessTraitsHelper<tgt_acc>::type;
GeometryTraits::check_valid_geometry_traits(tgt_point{});
static_assert(GeometryTraits::is_point<tgt_point>::value,
using TargetAccess =
ArborX::Details::AccessValues<TargetPoints, PrimitivesTag>;
static_assert(
KokkosExt::is_accessible_from<typename TargetAccess::memory_space,
ExecutionSpace>::value,
"target points must be accessible from the execution space");
using TargetPoint = typename TargetAccess::value_type;
GeometryTraits::check_valid_geometry_traits(TargetPoint{});
static_assert(GeometryTraits::is_point<TargetPoint>::value,
"target points elements must be points");
static_assert(dimension == GeometryTraits::dimension_v<tgt_point>,
static_assert(dimension == GeometryTraits::dimension_v<TargetPoint>,
"target and source points must have the same dimension");

int const num_targets = tgt_acc::size(target_points);
TargetAccess target_access{target_points};
int const num_targets = target_access.size();
int const num_neighbors = source_points.extent(1);

// There must be a set of neighbors for each target
Expand Down Expand Up @@ -104,7 +107,7 @@ movingLeastSquaresCoefficients(ExecutionSpace const &space,
space, {0, 0}, {num_targets, num_neighbors}),
KOKKOS_LAMBDA(int const i, int const j) {
auto src = source_points(i, j);
auto tgt = tgt_acc::get(target_points, i);
auto tgt = target_access(i);
point_t t{};

for (int k = 0; k < dimension; k++)
Expand Down

0 comments on commit 26ba073

Please sign in to comment.