Skip to content

Commit

Permalink
Merge pull request #978 from aprokop/access_traits_proliferation_pred…
Browse files Browse the repository at this point in the history
…icates

Cut down on the number of AccessTraits<Predicates, PredicatesTag>
  • Loading branch information
aprokop authored Dec 23, 2023
2 parents 8918428 + 9324554 commit 0a7cb09
Show file tree
Hide file tree
Showing 14 changed files with 144 additions and 162 deletions.
4 changes: 2 additions & 2 deletions .jenkins/continuous.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ pipeline {
}
}
}
stage('CUDA-11.0.3-NVCC') {
stage('CUDA-11.1.1-NVCC') {
agent {
dockerfile {
filename "Dockerfile"
dir "docker"
additionalBuildArgs '--build-arg BASE=nvidia/cuda:11.0.3-devel-ubuntu20.04 --build-arg KOKKOS_OPTIONS="-DCMAKE_CXX_EXTENSIONS=OFF -DKokkos_ENABLE_SERIAL=ON -DKokkos_ENABLE_OPENMP=ON -DKokkos_ENABLE_CUDA=ON -DKokkos_ARCH_VOLTA70=ON"'
additionalBuildArgs '--build-arg BASE=nvidia/cuda:11.1.1-devel-ubuntu20.04 --build-arg KOKKOS_OPTIONS="-DCMAKE_CXX_EXTENSIONS=OFF -DKokkos_ENABLE_SERIAL=ON -DKokkos_ENABLE_OPENMP=ON -DKokkos_ENABLE_CUDA=ON -DKokkos_ARCH_VOLTA70=ON"'
args '-v /tmp/ccache:/tmp/ccache --env NVIDIA_VISIBLE_DEVICES=${NVIDIA_VISIBLE_DEVICES}'
label 'NVIDIA_Tesla_V100-PCIE-32GB && nvidia-docker'
}
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/dbscan/ArborX_DBSCANVerification.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ bool verifyDBSCAN(ExecutionSpace exec_space, Primitives const &primitives,

static_assert(Kokkos::is_view<LabelsView>{});

using Points = Details::AccessValues<Primitives>;
using Points = Details::AccessValues<Primitives, PrimitivesTag>;
using MemorySpace = typename Points::memory_space;

static_assert(std::is_same<typename LabelsView::value_type, int>{});
Expand Down
32 changes: 18 additions & 14 deletions src/ArborX_BruteForce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,23 @@ class BruteForce
void query(ExecutionSpace const &space, Predicates const &predicates,
Callback const &callback, Ignore = Ignore()) const;

template <typename ExecutionSpace, typename Predicates,
template <typename ExecutionSpace, typename UserPredicates,
typename CallbackOrView, typename View, typename... Args>
std::enable_if_t<Kokkos::is_view_v<std::decay_t<View>>>
query(ExecutionSpace const &space, Predicates const &predicates,
query(ExecutionSpace const &space, UserPredicates const &user_predicates,
CallbackOrView &&callback_or_view, View &&view, Args &&...args) const
{
Kokkos::Profiling::ScopedRegion guard("ArborX::BruteForce::query_crs");

Details::CrsGraphWrapperImpl::
check_valid_callback_if_first_argument_is_not_a_view<value_type>(
callback_or_view, predicates, view);
callback_or_view, user_predicates, view);

using Access = AccessTraits<Predicates, PredicatesTag>;
using Tag = typename Details::AccessTraitsHelper<Access>::tag;
using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
using Tag = typename Predicates::value_type::Tag;

Details::CrsGraphWrapperImpl::queryDispatch(
Tag{}, *this, space, predicates,
Tag{}, *this, space, Predicates{user_predicates},
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
}
Expand Down Expand Up @@ -189,7 +189,7 @@ BruteForce<MemorySpace, Value, IndexableGetter, BoundingVolume>::BruteForce(
Details::check_valid_access_traits<UserValues>(
PrimitivesTag{}, user_values, Details::DoNotCheckGetReturnType());

using Values = Details::AccessValues<UserValues>;
using Values = Details::AccessValues<UserValues, PrimitivesTag>;
Values values{user_values};

static_assert(KokkosExt::is_accessible_from<typename Values::memory_space,
Expand All @@ -209,23 +209,27 @@ BruteForce<MemorySpace, Value, IndexableGetter, BoundingVolume>::BruteForce(

template <typename MemorySpace, typename Value, typename IndexableGetter,
typename BoundingVolume>
template <typename ExecutionSpace, typename Predicates, typename Callback,
template <typename ExecutionSpace, typename UserPredicates, typename Callback,
typename Ignore>
void BruteForce<MemorySpace, Value, IndexableGetter, BoundingVolume>::query(
ExecutionSpace const &space, Predicates const &predicates,
ExecutionSpace const &space, UserPredicates const &user_predicates,
Callback const &callback, Ignore) const
{
static_assert(
KokkosExt::is_accessible_from<MemorySpace, ExecutionSpace>::value);
Details::check_valid_access_traits(PredicatesTag{}, predicates);
using Access = AccessTraits<Predicates, PredicatesTag>;
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
Details::check_valid_access_traits(PredicatesTag{}, user_predicates);
Details::check_valid_callback<value_type>(callback, user_predicates);

using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
static_assert(KokkosExt::is_accessible_from<typename Predicates::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");
using Tag = typename Details::AccessTraitsHelper<Access>::tag;

Predicates predicates{user_predicates};

using Tag = typename Predicates::value_type::Tag;
static_assert(std::is_same<Tag, Details::SpatialPredicateTag>{},
"nearest query not implemented yet");
Details::check_valid_callback<Value>(callback, predicates);

Kokkos::Profiling::pushRegion("ArborX::BruteForce::query::spatial");

Expand Down
2 changes: 1 addition & 1 deletion src/ArborX_DBSCAN.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ dbscan(ExecutionSpace const &exec_space, Primitives const &primitives,
{
Kokkos::Profiling::pushRegion("ArborX::DBSCAN");

using Points = Details::AccessValues<Primitives>;
using Points = Details::AccessValues<Primitives, PrimitivesTag>;
using MemorySpace = typename Points::memory_space;

static_assert(
Expand Down
20 changes: 15 additions & 5 deletions src/ArborX_DistributedTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef ARBORX_DISTRIBUTED_TREE_HPP
#define ARBORX_DISTRIBUTED_TREE_HPP

#include <ArborX_AccessTraits.hpp>
#include <ArborX_Box.hpp>
#include <ArborX_DetailsDistributedTreeImpl.hpp>
#include <ArborX_DetailsUtils.hpp> // accumulate
Expand Down Expand Up @@ -86,13 +87,22 @@ class DistributedTree
* - \c distances Computed distances (optional and only for nearest
* predicates).
*/
template <typename ExecutionSpace, typename Predicates, typename... Args>
void query(ExecutionSpace const &space, Predicates const &predicates,
template <typename ExecutionSpace, typename UserPredicates, typename... Args>
void query(ExecutionSpace const &space, UserPredicates const &user_predicates,
Args &&...args) const
{
static_assert(Kokkos::is_execution_space<ExecutionSpace>::value);
using Access = AccessTraits<Predicates, PredicatesTag>;
using Tag = typename Details::AccessTraitsHelper<Access>::tag;
static_assert(
KokkosExt::is_accessible_from<MemorySpace, ExecutionSpace>::value);

using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
static_assert(
KokkosExt::is_accessible_from<typename Predicates::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");

Predicates predicates{user_predicates};

using Tag = typename Predicates::value_type::Tag;
using DeviceType = Kokkos::Device<ExecutionSpace, MemorySpace>;
Details::DistributedTreeImpl<DeviceType>::queryDispatch(
Tag{}, *this, space, predicates, std::forward<Args>(args)...);
Expand Down
30 changes: 16 additions & 14 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,23 @@ class BoundingVolumeHierarchy
Experimental::TraversalPolicy const &policy =
Experimental::TraversalPolicy()) const;

template <typename ExecutionSpace, typename Predicates,
template <typename ExecutionSpace, typename UserPredicates,
typename CallbackOrView, typename View, typename... Args>
std::enable_if_t<Kokkos::is_view_v<std::decay_t<View>>>
query(ExecutionSpace const &space, Predicates const &predicates,
query(ExecutionSpace const &space, UserPredicates const &user_predicates,
CallbackOrView &&callback_or_view, View &&view, Args &&...args) const
{
Kokkos::Profiling::ScopedRegion guard("ArborX::BVH::query_crs");

Details::CrsGraphWrapperImpl::
check_valid_callback_if_first_argument_is_not_a_view<value_type>(
callback_or_view, predicates, view);
callback_or_view, user_predicates, view);

using Access = AccessTraits<Predicates, PredicatesTag>;
using Tag = typename Details::AccessTraitsHelper<Access>::tag;
using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
using Tag = typename Predicates::value_type::Tag;

Details::CrsGraphWrapperImpl::queryDispatch(
Tag{}, *this, space, predicates,
Tag{}, *this, space, Predicates{user_predicates},
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
}
Expand Down Expand Up @@ -265,7 +265,7 @@ BoundingVolumeHierarchy<MemorySpace, Value, IndexableGetter, BoundingVolume>::
Details::check_valid_access_traits<UserValues>(
PrimitivesTag{}, user_values, Details::DoNotCheckGetReturnType());

using Values = Details::AccessValues<UserValues>;
using Values = Details::AccessValues<UserValues, PrimitivesTag>;
Values values{user_values};

static_assert(KokkosExt::is_accessible_from<typename Values::memory_space,
Expand Down Expand Up @@ -336,24 +336,26 @@ BoundingVolumeHierarchy<MemorySpace, Value, IndexableGetter, BoundingVolume>::

template <typename MemorySpace, typename Value, typename IndexableGetter,
typename BoundingVolume>
template <typename ExecutionSpace, typename Predicates, typename Callback>
template <typename ExecutionSpace, typename UserPredicates, typename Callback>
void BoundingVolumeHierarchy<
MemorySpace, Value, IndexableGetter,
BoundingVolume>::query(ExecutionSpace const &space,
Predicates const &predicates,
UserPredicates const &user_predicates,
Callback const &callback,
Experimental::TraversalPolicy const &policy) const
{
static_assert(
KokkosExt::is_accessible_from<MemorySpace, ExecutionSpace>::value);
Details::check_valid_access_traits(PredicatesTag{}, predicates);
using Access = AccessTraits<Predicates, PredicatesTag>;
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
Details::check_valid_access_traits(PredicatesTag{}, user_predicates);
Details::check_valid_callback<value_type>(callback, user_predicates);

using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
static_assert(KokkosExt::is_accessible_from<typename Predicates::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");
Details::check_valid_callback<value_type>(callback, predicates);
Predicates predicates{user_predicates};

using Tag = typename Details::AccessTraitsHelper<Access>::tag;
using Tag = typename Predicates::value_type::Tag;
std::string profiling_prefix = "ArborX::BVH::query::";
if constexpr (std::is_same_v<Tag, Details::SpatialPredicateTag>)
{
Expand Down
10 changes: 5 additions & 5 deletions src/details/ArborX_AccessTraits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,11 @@ void check_valid_access_traits(PrimitivesTag, Primitives const &,
}
}

template <typename Values>
template <typename Values, typename Tag>
class AccessValues
{
private:
using Access = AccessTraits<Values, PrimitivesTag>;
using Access = AccessTraits<Values, Tag>;

public:
Values _values;
Expand All @@ -221,10 +221,10 @@ class AccessValues

} // namespace Details

template <typename Values>
struct AccessTraits<Details::AccessValues<Values>, PrimitivesTag>
template <typename Values, typename Tag>
struct AccessTraits<Details::AccessValues<Values, Tag>, Tag>
{
using AccessValues = Details::AccessValues<Values>;
using AccessValues = Details::AccessValues<Values, Tag>;

using memory_space = typename AccessValues::memory_space;

Expand Down
26 changes: 9 additions & 17 deletions src/details/ArborX_DetailsBatchedQueries.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#ifndef ARBORX_DETAILS_BATCHED_QUERIES_HPP
#define ARBORX_DETAILS_BATCHED_QUERIES_HPP

#include <ArborX_AccessTraits.hpp>
#include <ArborX_Box.hpp>
#include <ArborX_DetailsAlgorithms.hpp> // returnCentroid, translateAndScale
#include <ArborX_DetailsKokkosExtViewHelpers.hpp>
Expand Down Expand Up @@ -53,11 +52,10 @@ struct BatchedQueries
Box const &scene_bounding_box,
Predicates const &predicates)
{
using Access = AccessTraits<Predicates, PredicatesTag>;
auto const n_queries = Access::size(predicates);
auto const n_queries = predicates.size();

using Point = std::decay_t<decltype(returnCentroid(
getGeometry(Access::get(predicates, 0))))>;
using Point =
std::decay_t<decltype(returnCentroid(getGeometry(predicates(0))))>;
using LinearOrderingValueType =
Kokkos::detected_t<SpaceFillingCurveProjectionArchetypeExpression,
SpaceFillingCurve, Box, Point>;
Expand All @@ -69,9 +67,8 @@ struct BatchedQueries
"ArborX::BatchedQueries::project_predicates_onto_space_filling_curve",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_queries),
KOKKOS_LAMBDA(int i) {
linear_ordering_indices(i) =
curve(scene_bounding_box,
returnCentroid(getGeometry(Access::get(predicates, i))));
linear_ordering_indices(i) = curve(
scene_bounding_box, returnCentroid(getGeometry(predicates(i))));
});

return sortObjects(space, linear_ordering_indices);
Expand All @@ -85,24 +82,19 @@ struct BatchedQueries
applyPermutation(ExecutionSpace const &space,
Kokkos::View<unsigned int const *, DeviceType> permute,
Predicates const &v)
-> Kokkos::View<typename AccessTraitsHelper<
AccessTraits<Predicates, PredicatesTag>>::type *,
DeviceType>
-> Kokkos::View<typename Predicates::value_type *, DeviceType>
{
using Access = AccessTraits<Predicates, PredicatesTag>;
auto const n = Access::size(v);
auto const n = v.size();
ARBORX_ASSERT(permute.extent(0) == n);

using T = std::decay_t<decltype(Access::get(
std::declval<Predicates const &>(), std::declval<int>()))>;
Kokkos::View<T *, DeviceType> w(
Kokkos::View<typename Predicates::value_type *, DeviceType> w(
Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
"ArborX::permuted_predicates"),
n);
Kokkos::parallel_for(
"ArborX::BatchedQueries::permute_entries",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n),
KOKKOS_LAMBDA(int i) { w(i) = Access::get(v, permute(i)); });
KOKKOS_LAMBDA(int i) { w(i) = v(permute(i)); });

return w;
}
Expand Down
9 changes: 3 additions & 6 deletions src/details/ArborX_DetailsBruteForceImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#ifndef ARBORX_DETAILS_BRUTE_FORCE_IMPL_HPP
#define ARBORX_DETAILS_BRUTE_FORCE_IMPL_HPP

#include <ArborX_AccessTraits.hpp>
#include <ArborX_DetailsAlgorithms.hpp> // expand
#include <ArborX_Exception.hpp>

Expand Down Expand Up @@ -53,12 +52,11 @@ struct BruteForceImpl
Callback const &callback)
{
using TeamPolicy = Kokkos::TeamPolicy<ExecutionSpace>;
using AccessPredicates = AccessTraits<Predicates, PredicatesTag>;
using PredicateType = typename AccessTraitsHelper<AccessPredicates>::type;
using PredicateType = typename Predicates::value_type;
using IndexableType = std::decay_t<decltype(indexables(0))>;

int const n_indexables = values.size();
int const n_predicates = AccessPredicates::size(predicates);
int const n_predicates = predicates.size();
int max_scratch_size = TeamPolicy::scratch_size_max(0);
// half of the scratch memory used by predicates and half for indexables
int const predicates_per_team =
Expand Down Expand Up @@ -110,8 +108,7 @@ struct BruteForceImpl
Kokkos::parallel_for(
Kokkos::TeamVectorRange(teamMember, predicates_in_this_team),
[&](const int q) {
scratch_predicates(q) =
AccessPredicates::get(predicates, predicate_start + q);
scratch_predicates(q) = predicates(predicate_start + q);
});
Kokkos::parallel_for(
Kokkos::TeamVectorRange(teamMember, indexables_in_this_team),
Expand Down
Loading

0 comments on commit 0a7cb09

Please sign in to comment.