Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use AccessTraits instead of hardcoded View in DBSCAN #509

Merged
merged 5 commits into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/dbscan/ArborX_DBSCANVerification.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ bool verifyDBSCAN(ExecutionSpace exec_space, Primitives const &primitives,

static_assert(Kokkos::is_view<LabelsView>{}, "");

using MemorySpace = typename Primitives::memory_space;
using Access = AccessTraits<Primitives, PrimitivesTag>;
using MemorySpace = typename Access::memory_space;

static_assert(std::is_same<typename LabelsView::value_type, int>{}, "");
static_assert(std::is_same<typename LabelsView::memory_space, MemorySpace>{},
Expand Down
43 changes: 29 additions & 14 deletions src/ArborX_DBSCAN.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef ARBORX_DBSCAN_HPP
#define ARBORX_DBSCAN_HPP

#include <ArborX_AccessTraits.hpp>
#include <ArborX_DetailsDBSCANCallback.hpp>
#include <ArborX_DetailsSortUtils.hpp>
#include <ArborX_DetailsUtils.hpp>
Expand All @@ -22,28 +23,36 @@
namespace ArborX
{

template <typename View>
template <typename Primitives>
struct PrimitivesWithRadius
{
View _M_view;
Primitives _primitives;
double _r;
};

template <typename View>
auto buildPredicates(View v, double r)
template <typename Primitives>
auto buildPredicates(Primitives const &v, double r)
{
return PrimitivesWithRadius<View>{v, r};
return PrimitivesWithRadius<Primitives>{v, r};
}

template <typename View>
struct AccessTraits<PrimitivesWithRadius<View>, PredicatesTag>
template <typename Primitives>
struct AccessTraits<PrimitivesWithRadius<Primitives>, PredicatesTag>
{
using memory_space = typename View::memory_space;
using Predicates = PrimitivesWithRadius<View>;
static size_t size(Predicates const &w) { return w._M_view.extent(0); }
using PrimitivesAccess = AccessTraits<Primitives, PrimitivesTag>;

using memory_space = typename PrimitivesAccess::memory_space;
using Predicates = PrimitivesWithRadius<Primitives>;

static size_t size(Predicates const &w)
{
return PrimitivesAccess::size(w._primitives);
}
static KOKKOS_FUNCTION auto get(Predicates const &w, size_t i)
{
return attach(intersects(Sphere{w._M_view(i), w._r}), (int)i);
return attach(
intersects(Sphere{PrimitivesAccess::get(w._primitives, i), w._r}),
(int)i);
}
};

Expand Down Expand Up @@ -103,14 +112,20 @@ struct Parameters
} // namespace DBSCAN

template <typename ExecutionSpace, typename Primitives>
Kokkos::View<int *, typename Primitives::memory_space>
Kokkos::View<int *,
typename AccessTraits<Primitives, PrimitivesTag>::memory_space>
dbscan(ExecutionSpace const &exec_space, Primitives const &primitives,
float eps, int core_min_size,
DBSCAN::Parameters const &parameters = DBSCAN::Parameters())
{
Kokkos::Profiling::pushRegion("ArborX::dbscan");

using MemorySpace = typename Primitives::memory_space;
using Access = AccessTraits<Primitives, PrimitivesTag>;
using MemorySpace = typename Access::memory_space;

static_assert(
KokkosExt::is_accessible_from<MemorySpace, ExecutionSpace>::value,
"Primitives must be accessible from the execution space");

ARBORX_ASSERT(eps > 0);
ARBORX_ASSERT(core_min_size >= 2);
Expand All @@ -134,7 +149,7 @@ dbscan(ExecutionSpace const &exec_space, Primitives const &primitives,

auto const predicates = buildPredicates(primitives, eps);

int const n = primitives.extent_int(0);
auto const n = Access::size(primitives);

// Build the tree
timer_start(timer);
Expand Down
30 changes: 30 additions & 0 deletions test/tstDBSCAN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,27 @@
#include "BoostTest_CUDA_clang_workarounds.hpp"
#include <boost/test/unit_test.hpp>

template <typename View>
struct HiddenView
{
View _view;
};
template <typename View>
struct ArborX::AccessTraits<HiddenView<View>, ArborX::PrimitivesTag>
{
using Data = HiddenView<View>;
static KOKKOS_FUNCTION std::size_t size(Data const &data)
{
return data._view.extent(0);
}
static KOKKOS_FUNCTION typename View::value_type const &get(Data const &data,
std::size_t i)
{
return data._view(i);
}
using memory_space = typename View::memory_space;
};

BOOST_AUTO_TEST_SUITE(DBSCAN)

template <typename DeviceType, typename T>
Expand Down Expand Up @@ -125,6 +146,15 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(dbscan, DeviceType, ARBORX_DEVICE_TYPES)
dbscan(space, points, r - 0.1, 2)));
BOOST_TEST(verifyDBSCAN(space, points, r, 2, dbscan(space, points, r, 2)));
BOOST_TEST(verifyDBSCAN(space, points, r, 3, dbscan(space, points, r, 3)));

// Test non-View primitives
HiddenView<decltype(points)> hidden_points{points};
BOOST_TEST(verifyDBSCAN(space, hidden_points, r - 0.1, 2,
dbscan(space, hidden_points, r - 0.1, 2)));
BOOST_TEST(verifyDBSCAN(space, hidden_points, r, 2,
dbscan(space, hidden_points, r, 2)));
BOOST_TEST(verifyDBSCAN(space, hidden_points, r, 3,
dbscan(space, hidden_points, r, 3)));
}

{
Expand Down