Skip to content

Commit

Permalink
[WIP]
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed Dec 3, 2020
1 parent f11c397 commit 86dd5ee
Show file tree
Hide file tree
Showing 7 changed files with 518 additions and 462 deletions.
32 changes: 21 additions & 11 deletions benchmarks/bvh_driver/bvh_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
****************************************************************************/

#include <ArborX_BoostRTreeHelpers.hpp>
#include <ArborX_CrsGraphWrapper.hpp>
#include <ArborX_LinearBVH.hpp>
#include <ArborX_Version.hpp>

Expand Down Expand Up @@ -209,21 +210,25 @@ template <class TreeType>
void BM_knn_search(benchmark::State &state, Spec const &spec)
{
using DeviceType = typename TreeType::device_type;
using ExecutionSpace = typename DeviceType::execution_space;

TreeType index(
constructPoints<DeviceType>(spec.n_values, spec.source_point_cloud_type));
auto const queries = makeNearestQueries<DeviceType>(
spec.n_values, spec.n_queries, spec.n_neighbors,
spec.target_point_cloud_type);

ArborX::CrsGraphWrapper<TreeType> crs_graph_index{index};

for (auto _ : state)
{
Kokkos::View<int *, DeviceType> offset("offset", 0);
Kokkos::View<int *, DeviceType> indices("indices", 0);
auto const start = std::chrono::high_resolution_clock::now();
index.query(queries, indices, offset,
ArborX::Experimental::TraversalPolicy().setPredicateSorting(
spec.sort_predicates));
crs_graph_index.query(
ExecutionSpace{}, queries, indices, offset,
ArborX::Experimental::TraversalPolicy().setPredicateSorting(
spec.sort_predicates));
auto const end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed_seconds = end - start;
state.SetIterationTime(elapsed_seconds.count());
Expand All @@ -247,6 +252,7 @@ template <class TreeType>
void BM_knn_callback_search(benchmark::State &state, Spec const &spec)
{
using DeviceType = typename TreeType::device_type;
using ExecutionSpace = typename DeviceType::execution_space;

TreeType index(
constructPoints<DeviceType>(spec.n_values, spec.source_point_cloud_type));
Expand All @@ -262,7 +268,7 @@ void BM_knn_callback_search(benchmark::State &state, Spec const &spec)
NearestCallback<DeviceType> callback{num_neigh};

auto const start = std::chrono::high_resolution_clock::now();
index.query(queries, callback,
index.query(ExecutionSpace{}, queries, callback,
ArborX::Experimental::TraversalPolicy().setPredicateSorting(
spec.sort_predicates));
auto const end = std::chrono::high_resolution_clock::now();
Expand All @@ -275,22 +281,25 @@ template <class TreeType>
void BM_radius_search(benchmark::State &state, Spec const &spec)
{
using DeviceType = typename TreeType::device_type;
using ExecutionSpace = typename DeviceType::execution_space;

TreeType index(
constructPoints<DeviceType>(spec.n_values, spec.source_point_cloud_type));
auto const queries = makeSpatialQueries<DeviceType>(
spec.n_values, spec.n_queries, spec.n_neighbors,
spec.target_point_cloud_type);

ArborX::CrsGraphWrapper<TreeType> crs_graph_index{index};

for (auto _ : state)
{
Kokkos::View<int *, DeviceType> offset("offset", 0);
Kokkos::View<int *, DeviceType> indices("indices", 0);
auto const start = std::chrono::high_resolution_clock::now();
index.query(queries, indices, offset,
ArborX::Experimental::TraversalPolicy()
.setPredicateSorting(spec.sort_predicates)
.setBufferSize(spec.buffer_size));
crs_graph_index.query(ExecutionSpace{}, queries, indices, offset,
ArborX::Experimental::TraversalPolicy()
.setPredicateSorting(spec.sort_predicates)
.setBufferSize(spec.buffer_size));
auto const end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed_seconds = end - start;
state.SetIterationTime(elapsed_seconds.count());
Expand All @@ -314,6 +323,7 @@ template <class TreeType>
void BM_radius_callback_search(benchmark::State &state, Spec const &spec)
{
using DeviceType = typename TreeType::device_type;
using ExecutionSpace = typename DeviceType::execution_space;

TreeType index(
constructPoints<DeviceType>(spec.n_values, spec.source_point_cloud_type));
Expand All @@ -329,7 +339,7 @@ void BM_radius_callback_search(benchmark::State &state, Spec const &spec)
SpatialCallback<DeviceType> callback{num_neigh};

auto const start = std::chrono::high_resolution_clock::now();
index.query(queries, callback,
index.query(ExecutionSpace{}, queries, callback,
ArborX::Experimental::TraversalPolicy()
.setPredicateSorting(spec.sort_predicates)
.setBufferSize(spec.buffer_size));
Expand Down Expand Up @@ -587,8 +597,8 @@ int main(int argc, char *argv[])
#endif

#ifdef KOKKOS_ENABLE_SERIAL
if (spec.backends == "all" || spec.backends == "rtree")
register_benchmark<BoostExt::RTree<ArborX::Point>>("BoostRTree", spec);
// if (spec.backends == "all" || spec.backends == "rtree")
// register_benchmark<BoostExt::RTree<ArborX::Point>>("BoostRTree", spec);
#endif
}

Expand Down
43 changes: 43 additions & 0 deletions src/ArborX_CrsGraphWrapper.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/****************************************************************************
* Copyright (c) 2012-2020 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
* distributed under a BSD 3-clause license. For the licensing terms see *
* the LICENSE file in the top-level directory. *
* *
* SPDX-License-Identifier: BSD-3-Clause *
****************************************************************************/

#ifndef ARBORX_CRS_GRAPH_WRAPPER_HPP
#define ARBORX_CRS_GRAPH_WRAPPER_HPP

#include "ArborX_DetailsCrsGraphWrapperImpl.hpp"

namespace ArborX
{

template <typename Tree>
class CrsGraphWrapper
{
public:
CrsGraphWrapper(Tree const &tree)
: _tree(tree)
{
}

template <typename ExecutionSpace, typename Predicates, typename... Args>
void query(ExecutionSpace const &space, Predicates const &predicates,
Args &&... args) const
{
Details::CrsGraphWrapperImpl::query(space, _tree, predicates,
std::forward<Args>(args)...);
}

private:
Tree _tree;
};

} // namespace ArborX

#endif
64 changes: 50 additions & 14 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,22 @@

#include <ArborX_AccessTraits.hpp>
#include <ArborX_Box.hpp>
#include <ArborX_DetailsBoundingVolumeHierarchyImpl.hpp>
#include <ArborX_Callbacks.hpp>
#include <ArborX_DetailsBatchedQueries.hpp>
#include <ArborX_DetailsConcepts.hpp>
#include <ArborX_DetailsKokkosExt.hpp>
#include <ArborX_DetailsNode.hpp>
#include <ArborX_DetailsPermutedData.hpp>
#include <ArborX_DetailsSortUtils.hpp>
#include <ArborX_DetailsTreeConstruction.hpp>
#include <ArborX_DetailsTreeTraversal.hpp>
#include <ArborX_TraversalPolicy.hpp>

#include <Kokkos_Core.hpp>

namespace ArborX
{

namespace Details
{
template <typename DeviceType>
Expand Down Expand Up @@ -55,19 +60,11 @@ class BoundingVolumeHierarchy
KOKKOS_FUNCTION
bounding_volume_type bounds() const noexcept { return _bounds; }

template <typename ExecutionSpace, typename Predicates, typename... Args>
template <typename ExecutionSpace, typename Predicates, typename Callback>
void query(ExecutionSpace const &space, Predicates const &predicates,
Args &&... args) const
{
Details::check_valid_access_traits(PredicatesTag{}, predicates);
using Access = AccessTraits<Predicates, PredicatesTag>;
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");

Details::BoundingVolumeHierarchyImpl::query(space, *this, predicates,
std::forward<Args>(args)...);
}
Callback const &callback,
Experimental::TraversalPolicy const &policy =
Experimental::TraversalPolicy()) const;

private:
template <typename BVH, typename Predicates, typename Callback,
Expand Down Expand Up @@ -154,7 +151,7 @@ class BoundingVolumeHierarchy<
void query(Args &&... args) const
{
BoundingVolumeHierarchy<typename DeviceType::memory_space>::query(
typename DeviceType::execution_space{}, std::forward<Args>(args)...);
std::forward<Args>(args)...);
}
};

Expand Down Expand Up @@ -228,6 +225,45 @@ BoundingVolumeHierarchy<MemorySpace, Enable>::BoundingVolumeHierarchy(
Kokkos::Profiling::popRegion();
}

template <typename MemorySpace, typename Enable>
template <typename ExecutionSpace, typename Predicates, typename Callback>
void BoundingVolumeHierarchy<MemorySpace, Enable>::query(
ExecutionSpace const &space, Predicates const &predicates,
Callback const &callback, Experimental::TraversalPolicy const &policy) const
{
Details::check_valid_access_traits(PredicatesTag{}, predicates);
using Access = AccessTraits<Predicates, PredicatesTag>;
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");

Kokkos::Profiling::pushRegion("ArborX::BVH::query");

// TODO check signature of the callback

auto const &bvh = *this;
if (policy._sort_predicates)
{
Kokkos::Profiling::pushRegion("ArborX::BVH::query::compute_permutation");
using DeviceType = Kokkos::Device<ExecutionSpace, MemorySpace>;
auto permute =
Details::BatchedQueries<DeviceType>::sortQueriesAlongZOrderCurve(
space, bounds(), predicates);
Kokkos::Profiling::popRegion();

using PermutedPredicates =
Details::PermutedData<Predicates, decltype(permute)>;
Details::traverse(space, bvh, PermutedPredicates{predicates, permute},
callback);
}
else
{
Details::traverse(space, bvh, predicates, callback);
}

Kokkos::Profiling::popRegion();
}

} // namespace ArborX

#endif
Loading

0 comments on commit 86dd5ee

Please sign in to comment.