Skip to content

Commit

Permalink
Merge pull request #1036 from aprokop/attach_indices
Browse files Browse the repository at this point in the history
Make attach_indices work for both primitives and predicates
  • Loading branch information
aprokop committed Feb 13, 2024
2 parents ea784f6 + d1de277 commit ef95952
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 108 deletions.
32 changes: 6 additions & 26 deletions benchmarks/bvh_driver/benchmark_registration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,26 +169,6 @@ makeNearestQueries(int n_values, int n_queries, int n_neighbors,
return queries;
}

template <typename Queries>
struct QueriesWithIndex
{
Queries _queries;
};

template <typename Queries>
struct ArborX::AccessTraits<QueriesWithIndex<Queries>, ArborX::PredicatesTag>
{
using memory_space = typename Queries::memory_space;
static KOKKOS_FUNCTION size_t size(QueriesWithIndex<Queries> const &q)
{
return q._queries.extent(0);
}
static KOKKOS_FUNCTION auto get(QueriesWithIndex<Queries> const &q, size_t i)
{
return attach(q._queries(i), (int)i);
}
};

template <typename DeviceType>
struct CountCallback
{
Expand Down Expand Up @@ -276,10 +256,9 @@ void BM_radius_callback_search(benchmark::State &state, Spec const &spec)
TreeType index(
ExecutionSpace{},
constructPoints<DeviceType>(spec.n_values, spec.source_point_cloud_type));
auto const queries_no_index = makeSpatialQueries<DeviceType>(
auto const queries = makeSpatialQueries<DeviceType>(
spec.n_values, spec.n_queries, spec.n_neighbors,
spec.target_point_cloud_type);
QueriesWithIndex<decltype(queries_no_index)> queries{queries_no_index};

for (auto _ : state)
{
Expand All @@ -290,7 +269,8 @@ void BM_radius_callback_search(benchmark::State &state, Spec const &spec)
exec_space.fence();
auto const start = std::chrono::high_resolution_clock::now();

index.query(exec_space, queries, callback,
index.query(exec_space, ArborX::Experimental::attach_indices<int>(queries),
callback,
ArborX::Experimental::TraversalPolicy().setPredicateSorting(
spec.sort_predicates));

Expand Down Expand Up @@ -348,10 +328,9 @@ void BM_knn_callback_search(benchmark::State &state, Spec const &spec)

TreeType index(exec_space, constructPoints<DeviceType>(
spec.n_values, spec.source_point_cloud_type));
auto const queries_no_index = makeNearestQueries<DeviceType>(
auto const queries = makeNearestQueries<DeviceType>(
spec.n_values, spec.n_queries, spec.n_neighbors,
spec.target_point_cloud_type);
QueriesWithIndex<decltype(queries_no_index)> queries{queries_no_index};

for (auto _ : state)
{
Expand All @@ -362,7 +341,8 @@ void BM_knn_callback_search(benchmark::State &state, Spec const &spec)
exec_space.fence();
auto const start = std::chrono::high_resolution_clock::now();

index.query(exec_space, queries, callback,
index.query(exec_space, ArborX::Experimental::attach_indices<int>(queries),
callback,
ArborX::Experimental::TraversalPolicy().setPredicateSorting(
spec.sort_predicates));

Expand Down
1 change: 1 addition & 0 deletions src/ArborX_BruteForce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#define ARBORX_BRUTE_FORCE_HPP

#include <ArborX_AccessTraits.hpp>
#include <ArborX_AttachIndices.hpp>
#include <ArborX_Box.hpp>
#include <ArborX_CrsGraphWrapper.hpp>
#include <ArborX_DetailsBruteForceImpl.hpp>
Expand Down
1 change: 1 addition & 0 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#define ARBORX_LINEAR_BVH_HPP

#include <ArborX_AccessTraits.hpp>
#include <ArborX_AttachIndices.hpp>
#include <ArborX_Box.hpp>
#include <ArborX_Callbacks.hpp>
#include <ArborX_CrsGraphWrapper.hpp>
Expand Down
85 changes: 85 additions & 0 deletions src/details/ArborX_AttachIndices.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/****************************************************************************
* Copyright (c) 2017-2022 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
* distributed under a BSD 3-clause license. For the licensing terms see *
* the LICENSE file in the top-level directory. *
* *
* SPDX-License-Identifier: BSD-3-Clause *
****************************************************************************/
#ifndef ARBORX_DETAILS_ATTACH_INDICES_HPP
#define ARBORX_DETAILS_ATTACH_INDICES_HPP

#include <ArborX_AccessTraits.hpp>
#include <ArborX_PairValueIndex.hpp>
#include <ArborX_Predicates.hpp>

namespace ArborX
{

namespace Experimental
{
template <typename Values, typename Index>
struct AttachIndices
{
Values _values;
};

// Make sure the default Index matches the default in PairValueIndex
template <typename Index = typename PairValueIndex<int>::index_type,
typename Values = void>
auto attach_indices(Values const &values)
{
return AttachIndices<Values, Index>{values};
}
} // namespace Experimental

} // namespace ArborX

template <typename Values, typename Index>
struct ArborX::AccessTraits<ArborX::Experimental::AttachIndices<Values, Index>,
ArborX::PrimitivesTag>
{
private:
using Self = ArborX::Experimental::AttachIndices<Values, Index>;
using Access = AccessTraits<Values, ArborX::PrimitivesTag>;
using value_type = ArborX::PairValueIndex<
std::decay_t<Kokkos::detected_t<
ArborX::Details::AccessTraitsGetArchetypeExpression, Access, Values>>,
Index>;

public:
using memory_space = typename Access::memory_space;

KOKKOS_FUNCTION static auto size(Self const &self)
{
return Access::size(self._values);
}
KOKKOS_FUNCTION static auto get(Self const &self, int i)
{
return value_type{Access::get(self._values, i), Index(i)};
}
};
template <typename Values, typename Index>
struct ArborX::AccessTraits<ArborX::Experimental::AttachIndices<Values, Index>,
ArborX::PredicatesTag>
{
private:
using Self = ArborX::Experimental::AttachIndices<Values, Index>;
using Access = AccessTraits<Values, ArborX::PredicatesTag>;

public:
using memory_space = typename Access::memory_space;

KOKKOS_FUNCTION static auto size(Self const &self)
{
return Access::size(self._values);
}
KOKKOS_FUNCTION static auto get(Self const &self, int i)
{
return attach(Access::get(self._values, i), Index(i));
}
};

#endif
49 changes: 0 additions & 49 deletions src/details/ArborX_PairValueIndex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,55 +31,6 @@ struct PairValueIndex
Index index;
};

namespace Experimental
{
template <typename Values, typename Index>
class AttachIndices
{
private:
using Data = Details::AccessValues<Values, PrimitivesTag>;

public:
Data _data;

using memory_space = typename Data::memory_space;
using value_type = PairValueIndex<typename Data::value_type, Index>;

AttachIndices(Values const &values)
: _data{values}
{}

KOKKOS_FUNCTION
auto operator()(int i) const { return value_type{_data(i), Index(i)}; }

KOKKOS_FUNCTION
auto size() const { return _data.size(); }
};

// Make sure the default Index matches the default in PairValueIndex
template <typename Index = typename PairValueIndex<int>::index_type,
typename Values = void>
auto attach_indices(Values const &values)
{
return AttachIndices<Values, Index>{values};
}
} // namespace Experimental

} // namespace ArborX

template <typename Values, typename Index>
struct ArborX::AccessTraits<ArborX::Experimental::AttachIndices<Values, Index>,
ArborX::PrimitivesTag>
{
using Self = ArborX::Experimental::AttachIndices<Values, Index>;

using memory_space = typename Self::memory_space;

KOKKOS_FUNCTION static auto size(Self const &values) { return values.size(); }
KOKKOS_FUNCTION static decltype(auto) get(Self const &values, int i)
{
return values(i);
}
};

#endif
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ add_executable(ArborX_Test_CompileOnly.exe
target_link_libraries(ArborX_Test_CompileOnly.exe PRIVATE ArborX)

add_executable(ArborX_Test_DetailsUtils.exe
tstAttachIndices.cpp
tstDetailsUtils.cpp
tstDetailsKokkosExtStdAlgorithms.cpp
tstDetailsKokkosExtMinMaxReduce.cpp
Expand Down
47 changes: 47 additions & 0 deletions test/tstAttachIndices.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/****************************************************************************
* Copyright (c) 2017-2022 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
* distributed under a BSD 3-clause license. For the licensing terms see *
* the LICENSE file in the top-level directory. *
* *
* SPDX-License-Identifier: BSD-3-Clause *
****************************************************************************/

#include <ArborX_AccessTraits.hpp>
#include <ArborX_AttachIndices.hpp>

#include <boost/test/unit_test.hpp>

BOOST_AUTO_TEST_SUITE(AttachIndices)

BOOST_AUTO_TEST_CASE(attach_indices_to_primitives)
{
using ArborX::Details::AccessValues;
using ArborX::Experimental::attach_indices;

Kokkos::View<ArborX::Point *, Kokkos::HostSpace> p("Testing::p", 10);
auto p_with_indices = attach_indices(p);
AccessValues<decltype(p_with_indices), ArborX::PrimitivesTag> p_values{
p_with_indices};
static_assert(std::is_same_v<decltype(p_values(0).index), unsigned>);
BOOST_TEST(p_values(0).index == 0);
BOOST_TEST(p_values(9).index == 9);
}

BOOST_AUTO_TEST_CASE(attach_indices_to_predicates)
{
using ArborX::Details::AccessValues;
using ArborX::Experimental::attach_indices;

using IntersectsPredicate = decltype(ArborX::intersects(ArborX::Point{}));
Kokkos::View<IntersectsPredicate *, Kokkos::HostSpace> q("Testing::q", 10);
auto q_with_indices = attach_indices<long>(q);
AccessValues<decltype(q_with_indices), ArborX::PredicatesTag> q_values{
q_with_indices};
BOOST_TEST(ArborX::getData(q_values(0)) == 0);
BOOST_TEST(ArborX::getData(q_values(9)) == 9);
}

BOOST_AUTO_TEST_SUITE_END()
57 changes: 48 additions & 9 deletions test/tstCompileOnlyAccessTraits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
****************************************************************************/

#include <ArborX_AccessTraits.hpp>
#include <ArborX_AttachIndices.hpp>
#include <ArborX_HyperPoint.hpp>
#include <ArborX_Point.hpp>

Expand Down Expand Up @@ -58,17 +59,55 @@ struct ArborX::Traits::Access<LegacyAccessTraits, Tag>
static Point get(LegacyAccessTraits, int) { return {}; }
};

template <class V, class Tag>
using deduce_type_t =
decltype(ArborX::AccessTraits<V, Tag>::get(std::declval<V>(), 0));

void test_access_traits_compile_only()
{
Kokkos::View<ArborX::Point *> p;
Kokkos::View<float **> v;
check_valid_access_traits(PrimitivesTag{}, p);
check_valid_access_traits(PrimitivesTag{}, v);

auto p_with_indices = ArborX::Experimental::attach_indices(p);
check_valid_access_traits(PrimitivesTag{}, p_with_indices,
ArborX::Details::DoNotCheckGetReturnType());
static_assert(
std::is_same_v<deduce_type_t<decltype(p_with_indices), PrimitivesTag>,
ArborX::PairValueIndex<ArborX::Point, unsigned>>);

auto p_with_indices_long = ArborX::Experimental::attach_indices<long>(p);
static_assert(std::is_same_v<
deduce_type_t<decltype(p_with_indices_long), PrimitivesTag>,
ArborX::PairValueIndex<ArborX::Point, long>>);

using NearestPredicate = decltype(ArborX::nearest(ArborX::Point{}));
Kokkos::View<NearestPredicate *> q;
check_valid_access_traits(PredicatesTag{}, q);

auto q_with_indices = ArborX::Experimental::attach_indices<long>(q);
check_valid_access_traits(PredicatesTag{}, q_with_indices);
using predicate = deduce_type_t<decltype(q_with_indices), PredicatesTag>;
static_assert(
std::is_same_v<
std::decay_t<decltype(ArborX::getData(std::declval<predicate>()))>,
long>);

struct CustomIndex
{
char index;
CustomIndex(int i) { index = i; }
};
auto q_with_custom_indices =
ArborX::Experimental::attach_indices<CustomIndex>(q);
check_valid_access_traits(PredicatesTag{}, q_with_custom_indices);
using predicate_custom =
deduce_type_t<decltype(q_with_custom_indices), PredicatesTag>;
static_assert(std::is_same_v<std::decay_t<decltype(ArborX::getData(
std::declval<predicate_custom>()))>,
CustomIndex>);

// Uncomment to see error messages

// check_valid_access_traits(PrimitivesTag{}, NoAccessTraitsSpecialization{});
Expand All @@ -82,21 +121,21 @@ void test_access_traits_compile_only()
// check_valid_access_traits(PrimitivesTag{}, LegacyAccessTraits{});
}

template <class V>
using deduce_point_t =
decltype(ArborX::AccessTraits<V, ArborX::PrimitivesTag>::get(
std::declval<V>(), 0));

void test_deduce_point_type_from_view()
{
using GoodOlePoint = ArborX::Point;
using ArborX::PrimitivesTag;
using ArborX::ExperimentalHyperGeometry::Point;
static_assert(
std::is_same_v<deduce_point_t<Kokkos::View<float **>>, GoodOlePoint>);
std::is_same_v<deduce_type_t<Kokkos::View<float **>, PrimitivesTag>,
GoodOlePoint>);
static_assert(
std::is_same_v<deduce_point_t<Kokkos::View<float *[3]>>, Point<3>>);
std::is_same_v<deduce_type_t<Kokkos::View<float *[3]>, PrimitivesTag>,
Point<3>>);
static_assert(
std::is_same_v<deduce_point_t<Kokkos::View<float *[2]>>, Point<2>>);
std::is_same_v<deduce_type_t<Kokkos::View<float *[2]>, PrimitivesTag>,
Point<2>>);
static_assert(
std::is_same_v<deduce_point_t<Kokkos::View<float *[5]>>, Point<5>>);
std::is_same_v<deduce_type_t<Kokkos::View<float *[5]>, PrimitivesTag>,
Point<5>>);
}
Loading

0 comments on commit ef95952

Please sign in to comment.