Skip to content

Commit

Permalink
Merge pull request #425 from aprokop/ext_buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed Dec 28, 2020
2 parents d288e72 + cc01971 commit 7040316
Show file tree
Hide file tree
Showing 19 changed files with 606 additions and 528 deletions.
15 changes: 8 additions & 7 deletions benchmarks/bvh_driver/bvh_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
****************************************************************************/

#include <ArborX_BoostRTreeHelpers.hpp>
#include <ArborX_CrsGraphWrapper.hpp>
#include <ArborX_LinearBVH.hpp>
#include <ArborX_Version.hpp>

Expand Down Expand Up @@ -225,9 +226,9 @@ void BM_knn_search(benchmark::State &state, Spec const &spec)
Kokkos::View<int *, DeviceType> offset("offset", 0);
Kokkos::View<int *, DeviceType> indices("indices", 0);
auto const start = std::chrono::high_resolution_clock::now();
index.query(ExecutionSpace{}, queries, indices, offset,
ArborX::Experimental::TraversalPolicy().setPredicateSorting(
spec.sort_predicates));
ArborX::query(index, ExecutionSpace{}, queries, indices, offset,
ArborX::Experimental::TraversalPolicy().setPredicateSorting(
spec.sort_predicates));
auto const end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed_seconds = end - start;
state.SetIterationTime(elapsed_seconds.count());
Expand Down Expand Up @@ -295,10 +296,10 @@ void BM_radius_search(benchmark::State &state, Spec const &spec)
Kokkos::View<int *, DeviceType> offset("offset", 0);
Kokkos::View<int *, DeviceType> indices("indices", 0);
auto const start = std::chrono::high_resolution_clock::now();
index.query(ExecutionSpace{}, queries, indices, offset,
ArborX::Experimental::TraversalPolicy()
.setPredicateSorting(spec.sort_predicates)
.setBufferSize(spec.buffer_size));
ArborX::query(index, ExecutionSpace{}, queries, indices, offset,
ArborX::Experimental::TraversalPolicy()
.setPredicateSorting(spec.sort_predicates)
.setBufferSize(spec.buffer_size));
auto const end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed_seconds = end - start;
state.SetIterationTime(elapsed_seconds.count());
Expand Down
2 changes: 1 addition & 1 deletion examples/access_traits/example_cuda_access_traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ int main(int argc, char *argv[])

Kokkos::View<int *, Kokkos::CudaSpace> indices("indices", 0);
Kokkos::View<int *, Kokkos::CudaSpace> offset("offset", 0);
bvh.query(cuda, Spheres{d_a, d_a, d_a, d_a, N}, indices, offset);
ArborX::query(bvh, cuda, Spheres{d_a, d_a, d_a, d_a, N}, indices, offset);

Kokkos::parallel_for(Kokkos::RangePolicy<Kokkos::Cuda>(cuda, 0, N),
KOKKOS_LAMBDA(int i) {
Expand Down
42 changes: 18 additions & 24 deletions examples/callback/example_callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,6 @@ struct AccessTraits<NearestToOrigin, PredicatesTag>
};
} // namespace ArborX

struct PairIndexDistance
{
int index;
float distance;
};

struct PrintfCallback
{
template <typename Predicate, typename OutputFunctor>
Expand Down Expand Up @@ -96,39 +90,39 @@ int main(int argc, char *argv[])
{
Kokkos::View<int *, MemorySpace> values("values", 0);
Kokkos::View<int *, MemorySpace> offsets("offsets", 0);
bvh.query(ExecutionSpace{}, FirstOctant{}, PrintfCallback{}, values,
offsets);
ArborX::query(bvh, ExecutionSpace{}, FirstOctant{}, PrintfCallback{},
values, offsets);
#ifndef __NVCC__
bvh.query(ExecutionSpace{}, FirstOctant{},
KOKKOS_LAMBDA(auto /*predicate*/, int primitive,
auto /*output_functor*/) {
ArborX::query(bvh, ExecutionSpace{}, FirstOctant{},
KOKKOS_LAMBDA(auto /*predicate*/, int primitive,
auto /*output_functor*/) {
#ifndef __SYCL_DEVICE_ONLY__
printf("Found %d from generic lambda\n", primitive);
printf("Found %d from generic lambda\n", primitive);
#else
(void)primitive;
(void)primitive;
#endif
},
values, offsets);
},
values, offsets);
#endif
}

{
int const k = 10;
Kokkos::View<int *, MemorySpace> values("values", 0);
Kokkos::View<int *, MemorySpace> offsets("offsets", 0);
bvh.query(ExecutionSpace{}, NearestToOrigin{k}, PrintfCallback{}, values,
offsets);
ArborX::query(bvh, ExecutionSpace{}, NearestToOrigin{k}, PrintfCallback{},
values, offsets);
#ifndef __NVCC__
bvh.query(ExecutionSpace{}, NearestToOrigin{k},
KOKKOS_LAMBDA(auto /*predicate*/, int primitive,
auto /*output_functor*/) {
ArborX::query(bvh, ExecutionSpace{}, NearestToOrigin{k},
KOKKOS_LAMBDA(auto /*predicate*/, int primitive,
auto /*output_functor*/) {
#ifndef __SYCL_DEVICE_ONLY__
printf("Found %d from generic lambda\n", primitive);
printf("Found %d from generic lambda\n", primitive);
#else
(void)primitive;
(void)primitive;
#endif
},
values, offsets);
},
values, offsets);
#endif
}

Expand Down
2 changes: 1 addition & 1 deletion examples/dbscan/ArborX_DBSCAN.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ void dbscan(ExecutionSpace exec_space, Primitives const &primitives,

Kokkos::View<int *, MemorySpace> indices("indices", 0);
Kokkos::View<int *, MemorySpace> offset("offset", 0);
bvh.query(exec_space, predicates, indices, offset);
ArborX::query(bvh, exec_space, predicates, indices, offset);

auto passed = Details::verifyClusters(exec_space, indices, offset, clusters,
core_min_size);
Expand Down
1 change: 1 addition & 0 deletions src/ArborX.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#ifdef ARBORX_ENABLE_MPI
#include <ArborX_DistributedTree.hpp>
#endif
#include <ArborX_CrsGraphWrapper.hpp>
#include <ArborX_Exception.hpp>
#include <ArborX_LinearBVH.hpp>
#include <ArborX_Point.hpp>
Expand Down
46 changes: 46 additions & 0 deletions src/ArborX_CrsGraphWrapper.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/****************************************************************************
* Copyright (c) 2012-2020 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
* distributed under a BSD 3-clause license. For the licensing terms see *
* the LICENSE file in the top-level directory. *
* *
* SPDX-License-Identifier: BSD-3-Clause *
****************************************************************************/

#ifndef ARBORX_CRS_GRAPH_WRAPPER_HPP
#define ARBORX_CRS_GRAPH_WRAPPER_HPP

#include "ArborX_DetailsCrsGraphWrapperImpl.hpp"

namespace ArborX
{

template <typename Tree, typename ExecutionSpace, typename Predicates,
typename CallbackOrView, typename View, typename... Args>
inline void query(Tree const &tree, ExecutionSpace const &space,
Predicates const &predicates,
CallbackOrView &&callback_or_view, View &&view,
Args &&... args)
{
Kokkos::Profiling::pushRegion("ArborX::query");

Details::CrsGraphWrapperImpl::
check_valid_callback_if_first_argument_is_not_a_view(callback_or_view,
predicates, view);

using Access = AccessTraits<Predicates, Traits::PredicatesTag>;
using Tag = typename Details::AccessTraitsHelper<Access>::tag;

ArborX::Details::CrsGraphWrapperImpl::queryDispatch(
Tag{}, tree, space, predicates,
std::forward<CallbackOrView>(callback_or_view), std::forward<View>(view),
std::forward<Args>(args)...);

Kokkos::Profiling::popRegion();
}

} // namespace ArborX

#endif
95 changes: 81 additions & 14 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,23 @@

#include <ArborX_AccessTraits.hpp>
#include <ArborX_Box.hpp>
#include <ArborX_DetailsBoundingVolumeHierarchyImpl.hpp>
#include <ArborX_Callbacks.hpp>
#include <ArborX_CrsGraphWrapper.hpp>
#include <ArborX_DetailsBatchedQueries.hpp>
#include <ArborX_DetailsConcepts.hpp>
#include <ArborX_DetailsKokkosExt.hpp>
#include <ArborX_DetailsNode.hpp>
#include <ArborX_DetailsPermutedData.hpp>
#include <ArborX_DetailsSortUtils.hpp>
#include <ArborX_DetailsTreeConstruction.hpp>
#include <ArborX_DetailsTreeTraversal.hpp>
#include <ArborX_TraversalPolicy.hpp>

#include <Kokkos_Core.hpp>

namespace ArborX
{

namespace Details
{
template <typename DeviceType>
Expand Down Expand Up @@ -57,18 +63,21 @@ class BoundingVolumeHierarchy
KOKKOS_FUNCTION
bounding_volume_type bounds() const noexcept { return _bounds; }

template <typename ExecutionSpace, typename Predicates, typename... Args>
template <typename ExecutionSpace, typename Predicates, typename Callback>
void query(ExecutionSpace const &space, Predicates const &predicates,
Args &&... args) const
Callback const &callback,
Experimental::TraversalPolicy const &policy =
Experimental::TraversalPolicy()) const;

template <typename ExecutionSpace, typename Predicates,
typename CallbackOrView, typename View, typename... Args>
std::enable_if_t<Kokkos::is_view<std::decay_t<View>>{}>
query(ExecutionSpace const &space, Predicates const &predicates,
CallbackOrView &&callback_or_view, View &&view, Args &&... args) const
{
Details::check_valid_access_traits(PredicatesTag{}, predicates);
using Access = AccessTraits<Predicates, PredicatesTag>;
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");

Details::BoundingVolumeHierarchyImpl::query(space, *this, predicates,
std::forward<Args>(args)...);
ArborX::query(*this, space, predicates,
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
}

private:
Expand Down Expand Up @@ -167,11 +176,29 @@ class BoundingVolumeHierarchy<
{
}
// clang-format on
template <typename... Args>
void query(Args &&... args) const
template <typename FirstArgumentType, typename... Args>
std::enable_if_t<!Kokkos::is_execution_space<FirstArgumentType>::value>
query(FirstArgumentType &&arg1, Args &&... args) const
{
BoundingVolumeHierarchy<typename DeviceType::memory_space>::query(
typename DeviceType::execution_space{},
std::forward<FirstArgumentType>(arg1), std::forward<Args>(args)...);
}

private:
template <typename Tree, typename ExecutionSpace, typename Predicates,
typename CallbackOrView, typename View, typename... Args>
friend void ArborX::query(Tree const &tree, ExecutionSpace const &space,
Predicates const &predicates,
CallbackOrView &&callback_or_view, View &&view,
Args &&... args);

template <typename FirstArgumentType, typename... Args>
std::enable_if_t<Kokkos::is_execution_space<FirstArgumentType>::value>
query(FirstArgumentType const &space, Args &&... args) const
{
BoundingVolumeHierarchy<typename DeviceType::memory_space>::query(
typename DeviceType::execution_space{}, std::forward<Args>(args)...);
space, std::forward<Args>(args)...);
}
};

Expand Down Expand Up @@ -245,6 +272,46 @@ BoundingVolumeHierarchy<MemorySpace, Enable>::BoundingVolumeHierarchy(
Kokkos::Profiling::popRegion();
}

template <typename MemorySpace, typename Enable>
template <typename ExecutionSpace, typename Predicates, typename Callback>
void BoundingVolumeHierarchy<MemorySpace, Enable>::query(
ExecutionSpace const &space, Predicates const &predicates,
Callback const &callback, Experimental::TraversalPolicy const &policy) const
{
Details::check_valid_access_traits(PredicatesTag{}, predicates);

using Access = AccessTraits<Predicates, Traits::PredicatesTag>;
using Tag = typename Details::AccessTraitsHelper<Access>::tag;

auto profiling_prefix =
std::string("ArborX::BVH::query::") +
(std::is_same<Tag, Details::SpatialPredicateTag>{} ? "spatial"
: "nearest");

Kokkos::Profiling::pushRegion(profiling_prefix);

if (policy._sort_predicates)
{
Kokkos::Profiling::pushRegion(profiling_prefix + "::compute_permutation");
using DeviceType = Kokkos::Device<ExecutionSpace, MemorySpace>;
auto permute =
Details::BatchedQueries<DeviceType>::sortQueriesAlongZOrderCurve(
space, bounds(), predicates);
Kokkos::Profiling::popRegion();

using PermutedPredicates =
Details::PermutedData<Predicates, decltype(permute)>;
Details::traverse(space, *this, PermutedPredicates{predicates, permute},
callback);
}
else
{
Details::traverse(space, *this, predicates, callback);
}

Kokkos::Profiling::popRegion();
}

} // namespace ArborX

#endif
Loading

0 comments on commit 7040316

Please sign in to comment.