Skip to content

Commit

Permalink
Replace suitable instances with helper predicate creation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed Feb 13, 2024
1 parent e97cb4e commit b13835c
Show file tree
Hide file tree
Showing 13 changed files with 49 additions and 295 deletions.
3 changes: 2 additions & 1 deletion benchmarks/dbscan/ArborX_DBSCANVerification.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,8 @@ bool verifyDBSCAN(ExecutionSpace exec_space, Primitives const &primitives,
ArborX::BoundingVolumeHierarchy<MemorySpace, ArborX::PairValueIndex<Point>>
bvh(exec_space, ArborX::Experimental::attach_indices(points));

auto const predicates = Details::PrimitivesWithRadius<Points>{points, eps};
auto const predicates = ArborX::Experimental::attach_indices(
ArborX::Experimental::intersect_geometries_with_radius(points, eps));

Kokkos::View<int *, MemorySpace> indices("ArborX::DBSCAN::indices", 0);
Kokkos::View<int *, MemorySpace> offset("ArborX::DBSCAN::offset", 0);
Expand Down
58 changes: 7 additions & 51 deletions benchmarks/distributed_tree_driver/distributed_tree_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,52 +170,6 @@ class TimeMonitor
}
};

template <typename DeviceType>
struct NearestNeighborsSearches
{
Kokkos::View<ArborX::Point *, DeviceType> points;
int k;
};
template <typename DeviceType>
struct RadiusSearches
{
Kokkos::View<ArborX::Point *, DeviceType> points;
double radius;
};

template <typename DeviceType>
struct ArborX::AccessTraits<RadiusSearches<DeviceType>, ArborX::PredicatesTag>
{
using memory_space = typename DeviceType::memory_space;
static KOKKOS_FUNCTION std::size_t
size(RadiusSearches<DeviceType> const &pred)
{
return pred.points.extent(0);
}
static KOKKOS_FUNCTION auto get(RadiusSearches<DeviceType> const &pred,
std::size_t i)
{
return ArborX::intersects(ArborX::Sphere{pred.points(i), pred.radius});
}
};

template <typename DeviceType>
struct ArborX::AccessTraits<NearestNeighborsSearches<DeviceType>,
ArborX::PredicatesTag>
{
using memory_space = typename DeviceType::memory_space;
static KOKKOS_FUNCTION std::size_t
size(NearestNeighborsSearches<DeviceType> const &pred)
{
return pred.points.extent(0);
}
static KOKKOS_FUNCTION auto
get(NearestNeighborsSearches<DeviceType> const &pred, std::size_t i)
{
return ArborX::nearest(pred.points(i), pred.k);
}
};

namespace bpo = boost::program_options;

template <class NO>
Expand Down Expand Up @@ -421,8 +375,8 @@ int main_(std::vector<std::string> const &args, MPI_Comm const comm)
knn->start();
distributed_tree.query(
ExecutionSpace{},
NearestNeighborsSearches<DeviceType>{random_queries, n_neighbors},
values, offsets);
ArborX::Experimental::nearest_k(random_queries, n_neighbors), values,
offsets);
knn->stop();

if (comm_rank == 0)
Expand Down Expand Up @@ -457,9 +411,11 @@ int main_(std::vector<std::string> const &args, MPI_Comm const comm)
auto radius = time_monitor.getNewTimer("radius");
MPI_Barrier(comm);
radius->start();
distributed_tree.query(ExecutionSpace{},
RadiusSearches<DeviceType>{random_queries, r},
values, offsets);
distributed_tree.query(
ExecutionSpace{},
ArborX::Experimental::intersect_geometries_with_radius(random_queries,
r),
values, offsets);
radius->stop();

if (comm_rank == 0)
Expand Down
29 changes: 5 additions & 24 deletions examples/molecular_dynamics/example_molecular_dynamics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,6 @@
#include <iostream>
#include <type_traits>

template <class MemorySpace>
struct Neighbors
{
Kokkos::View<ArborX::Point *, MemorySpace> _particles;
float _radius;
};

template <class MemorySpace>
struct ArborX::AccessTraits<Neighbors<MemorySpace>, ArborX::PredicatesTag>
{
using memory_space = MemorySpace;
using size_type = std::size_t;
static KOKKOS_FUNCTION size_type size(Neighbors<MemorySpace> const &x)
{
return x._particles.extent(0);
}
static KOKKOS_FUNCTION auto get(Neighbors<MemorySpace> const &x, size_type i)
{
return attach(intersects(Sphere{x._particles(i), x._radius}), (int)i);
}
};

struct ExcludeSelfCollision
{
template <class Predicate, class OutputFunctor>
Expand Down Expand Up @@ -119,8 +97,11 @@ int main(int argc, char *argv[])

Kokkos::View<int *, MemorySpace> indices("Example::indices", 0);
Kokkos::View<int *, MemorySpace> offsets("Example::offsets", 0);
index.query(execution_space, Neighbors<MemorySpace>{particles, r},
ExcludeSelfCollision{}, indices, offsets);
index.query(
execution_space,
ArborX::Experimental::attach_indices<int>(
ArborX::Experimental::intersect_geometries_with_radius(particles, r)),
ExcludeSelfCollision{}, indices, offsets);

Kokkos::View<float *[3], MemorySpace> forces(
Kokkos::view_alloc(execution_space, "Example::forces"), n);
Expand Down
31 changes: 3 additions & 28 deletions examples/raytracing/example_raytracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,6 @@ struct ArborX::AccessTraits<OrderedIntersectsBased::Rays<MemorySpace>,

namespace IntersectsBased
{
/*
* Storage for the rays and access traits used in the query/traverse.
*/
template <typename MemorySpace>
struct Rays
{
Kokkos::View<ArborX::Experimental::Ray *, MemorySpace> _rays;
};

/*
* IntersectedCell is a storage container for all intersections between rays and
Expand Down Expand Up @@ -177,25 +169,6 @@ struct AccumulateRaySphereIntersections
};
} // namespace IntersectsBased

template <typename MemorySpace>
struct ArborX::AccessTraits<IntersectsBased::Rays<MemorySpace>,
ArborX::PredicatesTag>
{
using memory_space = MemorySpace;
using size_type = std::size_t;

KOKKOS_FUNCTION
static size_type size(IntersectsBased::Rays<MemorySpace> const &rays)
{
return rays._rays.extent(0);
}
KOKKOS_FUNCTION
static auto get(IntersectsBased::Rays<MemorySpace> const &rays, size_type i)
{
return attach(intersects(rays._rays(i)), (int)i);
}
};

int main(int argc, char *argv[])
{
using ExecutionSpace = Kokkos::DefaultExecutionSpace;
Expand Down Expand Up @@ -336,7 +309,9 @@ int main(int argc, char *argv[])
0);
Kokkos::View<int *> offsets("Example::offsets", 0);
bvh.query(
exec_space, IntersectsBased::Rays<MemorySpace>{rays},
exec_space,
ArborX::Experimental::attach_indices<int>(
ArborX::Experimental::intersect_geometries(rays)),
IntersectsBased::AccumulateRaySphereIntersections<MemorySpace>{boxes},
values, offsets);

Expand Down
45 changes: 7 additions & 38 deletions src/ArborX_DBSCAN.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <ArborX_HyperBox.hpp>
#include <ArborX_HyperSphere.hpp>
#include <ArborX_LinearBVH.hpp>
#include <ArborX_PredicateHelpers.hpp>
#include <ArborX_Sphere.hpp>

namespace ArborX
Expand Down Expand Up @@ -52,13 +53,6 @@ struct DBSCANCorePoints
}
};

template <typename Primitives>
struct PrimitivesWithRadius
{
Primitives _primitives;
float _r;
};

struct WithinRadiusGetter
{
float _r;
Expand Down Expand Up @@ -100,31 +94,6 @@ struct MixedBoxPrimitives

} // namespace Details

template <typename Primitives>
struct AccessTraits<Details::PrimitivesWithRadius<Primitives>, PredicatesTag>
{
using memory_space = typename Primitives::memory_space;
using Predicates = Details::PrimitivesWithRadius<Primitives>;

static KOKKOS_FUNCTION size_t size(Predicates const &w)
{
return w._primitives.size();
}
static KOKKOS_FUNCTION auto get(Predicates const &w, size_t i)
{
auto const &point = w._primitives(i);
constexpr int dim =
GeometryTraits::dimension_v<std::decay_t<decltype(point)>>;
// FIXME reinterpret_cast is dangerous here if access traits return user
// point structure (e.g., struct MyPoint { float y; float x; })
auto const &hyper_point =
reinterpret_cast<ExperimentalHyperGeometry::Point<dim> const &>(point);
return attach(
intersects(ExperimentalHyperGeometry::Sphere<dim>{hyper_point, w._r}),
(int)i);
}
};

template <typename Primitives, typename PermuteFilter>
struct AccessTraits<Details::PrimitivesWithRadiusReorderedAndFiltered<
Primitives, PermuteFilter>,
Expand Down Expand Up @@ -315,8 +284,8 @@ dbscan(ExecutionSpace const &exec_space, Primitives const &primitives,
}
else
{
auto const predicates =
Details::PrimitivesWithRadius<Points>{points, eps};
auto const predicates = ArborX::Experimental::attach_indices(
ArborX::Experimental::intersect_geometries_with_radius(points, eps));

// Determine core points
Kokkos::Profiling::pushRegion("ArborX::DBSCAN::clusters::num_neigh");
Expand Down Expand Up @@ -437,8 +406,8 @@ dbscan(ExecutionSpace const &exec_space, Primitives const &primitives,
// Perform the queries and build clusters through callback
using CorePoints = Details::CCSCorePoints;
Kokkos::Profiling::pushRegion("ArborX::DBSCAN::clusters::query");
auto const predicates =
Details::PrimitivesWithRadius<Points>{points, eps};
auto const predicates = Experimental::attach_indices(
Experimental::intersect_geometries_with_radius(points, eps));
bvh.query(exec_space, predicates,
Details::FDBSCANDenseBoxCallback<UnionFind, CorePoints, Points,
decltype(dense_cell_offsets),
Expand Down Expand Up @@ -478,8 +447,8 @@ dbscan(ExecutionSpace const &exec_space, Primitives const &primitives,

// Perform the queries and build clusters through callback
Kokkos::Profiling::pushRegion("ArborX::DBSCAN::clusters::query");
auto const predicates =
Details::PrimitivesWithRadius<Points>{points, eps};
auto const predicates = Experimental::attach_indices(
Experimental::intersect_geometries_with_radius(points, eps));
bvh.query(exec_space, predicates,
Details::FDBSCANDenseBoxCallback<UnionFind, CorePoints, Points,
decltype(dense_cell_offsets),
Expand Down
2 changes: 1 addition & 1 deletion src/details/ArborX_DetailsFDBSCANDenseBox.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ struct FDBSCANDenseBoxCallback
}
else
{
int j = _permute(_num_points_in_dense_cells + (k - _num_dense_cells));
auto j = _permute(_num_points_in_dense_cells + (k - _num_dense_cells));

// No need to check the distance here, as the fact that we are inside the
// callback guarantees that it is <= eps
Expand Down
28 changes: 0 additions & 28 deletions src/details/ArborX_DetailsMutualReachabilityDistance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,34 +47,6 @@ struct MaxDistance
}
};

template <class Primitives>
struct NearestK
{
Primitives primitives;
int k; // including self-collisions
};

} // namespace Details

template <class Primitives>
struct AccessTraits<Details::NearestK<Primitives>, PredicatesTag>
{
using memory_space = typename Primitives::memory_space;
using size_type = typename memory_space::size_type;
static KOKKOS_FUNCTION size_type size(Details::NearestK<Primitives> const &x)
{
return x.primitives.size();
}
static KOKKOS_FUNCTION auto get(Details::NearestK<Primitives> const &x,
size_type i)
{
return attach(nearest(x.primitives(i), x.k), i);
}
};

namespace Details
{

template <class CoreDistances>
struct MutualReachability
{
Expand Down
9 changes: 6 additions & 3 deletions src/details/ArborX_MinimumSpanningTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <ArborX_DetailsTreeNodeLabeling.hpp>
#include <ArborX_DetailsWeightedEdge.hpp>
#include <ArborX_LinearBVH.hpp>
#include <ArborX_PredicateHelpers.hpp>

#include <Kokkos_Core.hpp>
#include <Kokkos_Profiling_ScopedRegion.hpp>
Expand Down Expand Up @@ -66,9 +67,11 @@ struct MinimumSpanningTree
Kokkos::Profiling::pushRegion("ArborX::MST::compute_core_distances");
Kokkos::View<float *, MemorySpace> core_distances(
"ArborX::MST::core_distances", n);
bvh.query(space, NearestK<Points>{points, k},
MaxDistance<Points, decltype(core_distances)>{points,
core_distances});
bvh.query(
space,
Experimental::attach_indices(Experimental::nearest_k(points, k)),
MaxDistance<Points, decltype(core_distances)>{points,
core_distances});
Kokkos::Profiling::popRegion();

MutualReachability<decltype(core_distances)> mutual_reachability{
Expand Down
8 changes: 6 additions & 2 deletions src/details/ArborX_PredicateHelpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,13 @@ struct PrimitivesNearestK
{
private:
using Primitives = Details::AccessValues<UserPrimitives, PrimitivesTag>;
// FIXME:
// using Geometry = typename Primitives::value_type;
// static_assert(GeometryTraits::is_valid_geometry<Geometry>{});

public:
Primitives _primitives;
int _k; // not including self-collisions
int _k;
};

template <typename Primitives>
Expand All @@ -83,7 +86,8 @@ auto intersect_geometries_with_radius(Primitives const &primitives,
template <typename Primitives>
auto nearest_k(Primitives const &primitives, int k)
{
Details::check_valid_access_traits(PrimitivesTag{}, primitives);
Details::check_valid_access_traits(PrimitivesTag{}, primitives,
Details::DoNotCheckGetReturnType());
return PrimitivesNearestK<Primitives>{primitives, k};
}

Expand Down

0 comments on commit b13835c

Please sign in to comment.