diff --git a/examples/dbscan/ArborX_DBSCANVerification.hpp b/examples/dbscan/ArborX_DBSCANVerification.hpp index bedaf2bcc..ee4fdf8c5 100644 --- a/examples/dbscan/ArborX_DBSCANVerification.hpp +++ b/examples/dbscan/ArborX_DBSCANVerification.hpp @@ -261,7 +261,8 @@ bool verifyDBSCAN(ExecutionSpace exec_space, Primitives const &primitives, static_assert(Kokkos::is_view{}, ""); - using MemorySpace = typename Primitives::memory_space; + using Access = AccessTraits; + using MemorySpace = typename Access::memory_space; static_assert(std::is_same{}, ""); static_assert(std::is_same{}, diff --git a/src/ArborX_DBSCAN.hpp b/src/ArborX_DBSCAN.hpp index 1ae83ceb7..863f91acd 100644 --- a/src/ArborX_DBSCAN.hpp +++ b/src/ArborX_DBSCAN.hpp @@ -12,6 +12,7 @@ #ifndef ARBORX_DBSCAN_HPP #define ARBORX_DBSCAN_HPP +#include #include #include #include @@ -22,28 +23,36 @@ namespace ArborX { -template +template struct PrimitivesWithRadius { - View _M_view; + Primitives _primitives; double _r; }; -template -auto buildPredicates(View v, double r) +template +auto buildPredicates(Primitives const &v, double r) { - return PrimitivesWithRadius{v, r}; + return PrimitivesWithRadius{v, r}; } -template -struct AccessTraits, PredicatesTag> +template +struct AccessTraits, PredicatesTag> { - using memory_space = typename View::memory_space; - using Predicates = PrimitivesWithRadius; - static size_t size(Predicates const &w) { return w._M_view.extent(0); } + using PrimitivesAccess = AccessTraits; + + using memory_space = typename PrimitivesAccess::memory_space; + using Predicates = PrimitivesWithRadius; + + static size_t size(Predicates const &w) + { + return PrimitivesAccess::size(w._primitives); + } static KOKKOS_FUNCTION auto get(Predicates const &w, size_t i) { - return attach(intersects(Sphere{w._M_view(i), w._r}), (int)i); + return attach( + intersects(Sphere{PrimitivesAccess::get(w._primitives, i), w._r}), + (int)i); } }; @@ -103,14 +112,20 @@ struct Parameters } // namespace DBSCAN template -Kokkos::View +Kokkos::View::memory_space> dbscan(ExecutionSpace const &exec_space, Primitives const &primitives, float eps, int core_min_size, DBSCAN::Parameters const ¶meters = DBSCAN::Parameters()) { Kokkos::Profiling::pushRegion("ArborX::dbscan"); - using MemorySpace = typename Primitives::memory_space; + using Access = AccessTraits; + using MemorySpace = typename Access::memory_space; + + static_assert( + KokkosExt::is_accessible_from::value, + "Primitives must be accessible from the execution space"); ARBORX_ASSERT(eps > 0); ARBORX_ASSERT(core_min_size >= 2); @@ -134,7 +149,7 @@ dbscan(ExecutionSpace const &exec_space, Primitives const &primitives, auto const predicates = buildPredicates(primitives, eps); - int const n = primitives.extent_int(0); + auto const n = Access::size(primitives); // Build the tree timer_start(timer); diff --git a/test/tstDBSCAN.cpp b/test/tstDBSCAN.cpp index 207e3e17b..cd1d60222 100644 --- a/test/tstDBSCAN.cpp +++ b/test/tstDBSCAN.cpp @@ -15,6 +15,27 @@ #include "BoostTest_CUDA_clang_workarounds.hpp" #include +template +struct HiddenView +{ + View _view; +}; +template +struct ArborX::AccessTraits, ArborX::PrimitivesTag> +{ + using Data = HiddenView; + static KOKKOS_FUNCTION std::size_t size(Data const &data) + { + return data._view.extent(0); + } + static KOKKOS_FUNCTION typename View::value_type const &get(Data const &data, + std::size_t i) + { + return data._view(i); + } + using memory_space = typename View::memory_space; +}; + BOOST_AUTO_TEST_SUITE(DBSCAN) template @@ -125,6 +146,15 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(dbscan, DeviceType, ARBORX_DEVICE_TYPES) dbscan(space, points, r - 0.1, 2))); BOOST_TEST(verifyDBSCAN(space, points, r, 2, dbscan(space, points, r, 2))); BOOST_TEST(verifyDBSCAN(space, points, r, 3, dbscan(space, points, r, 3))); + + // Test non-View primitives + HiddenView hidden_points{points}; + BOOST_TEST(verifyDBSCAN(space, hidden_points, r - 0.1, 2, + dbscan(space, hidden_points, r - 0.1, 2))); + BOOST_TEST(verifyDBSCAN(space, hidden_points, r, 2, + dbscan(space, hidden_points, r, 2))); + BOOST_TEST(verifyDBSCAN(space, hidden_points, r, 3, + dbscan(space, hidden_points, r, 3))); } {