diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index aef663c35..c2c1a1a54 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -11,6 +11,8 @@ endif() add_subdirectory(access_traits) +add_subdirectory(callback) + find_package(Boost COMPONENTS program_options) if(Boost_FOUND) add_subdirectory(viz) diff --git a/examples/callback/CMakeLists.txt b/examples/callback/CMakeLists.txt new file mode 100644 index 000000000..eb1990c4a --- /dev/null +++ b/examples/callback/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(ArborX_Callback.exe example_callback.cpp) +target_link_libraries(ArborX_Callback.exe ${ArborX_TARGET}) +add_test(NAME ArborX_Callback_Example COMMAND ./ArborX_Callback.exe) diff --git a/examples/callback/example_callback.cpp b/examples/callback/example_callback.cpp new file mode 100644 index 000000000..372a805bf --- /dev/null +++ b/examples/callback/example_callback.cpp @@ -0,0 +1,149 @@ +/**************************************************************************** + * Copyright (c) 2012-2020 by the ArborX authors * + * All rights reserved. * + * * + * This file is part of the ArborX library. ArborX is * + * distributed under a BSD 3-clause license. For the licensing terms see * + * the LICENSE file in the top-level directory. * + * * + * SPDX-License-Identifier: BSD-3-Clause * + ****************************************************************************/ + +#include + +#include + +#include +#include +#include + +using ExecutionSpace = Kokkos::DefaultExecutionSpace; +using MemorySpace = ExecutionSpace::memory_space; + +struct FirstOctant +{ +}; + +struct NearestToOrigin +{ + int k; +}; + +namespace ArborX +{ +namespace Traits +{ +template <> +struct Access +{ + KOKKOS_FUNCTION static std::size_t size(FirstOctant) { return 1; } + KOKKOS_FUNCTION static auto get(FirstOctant, std::size_t) + { + return intersects(Box{{{0, 0, 0}}, {{1, 1, 1}}}); + } + using memory_space = MemorySpace; +}; +template <> +struct Access +{ + KOKKOS_FUNCTION static std::size_t size(NearestToOrigin) { return 1; } + KOKKOS_FUNCTION static auto get(NearestToOrigin d, std::size_t) + { + return nearest(Point{0, 0, 0}, d.k); + } + using memory_space = MemorySpace; +}; +} // namespace Traits +} // namespace ArborX + +struct PairIndexDistance +{ + int index; + float distance; +}; + +struct PrintfCallback +{ + template + KOKKOS_FUNCTION void operator()(Predicate, int primitive, + OutputFunctor const &out) const + { + printf("Found %d from functor\n", primitive); + out(primitive); + } + template + KOKKOS_FUNCTION void operator()(Predicate, int primitive, float distance, + OutputFunctor const &out) const + { + printf("Found %d with distance %.3f from functor\n", primitive, distance); + out({primitive, distance}); + } +}; + +int main(int argc, char *argv[]) +{ + Kokkos::ScopeGuard guard(argc, argv); + + int const n = 100; + std::vector points; + // Fill vector with random points in [-1, 1]^3 + std::uniform_real_distribution dis{-1., 1.}; + std::default_random_engine gen; + auto rd = [&]() { return dis(gen); }; + std::generate_n(std::back_inserter(points), n, [&]() { + return ArborX::Point{rd(), rd(), rd()}; + }); + + ArborX::BVH bvh{ + ExecutionSpace{}, + Kokkos::create_mirror_view_and_copy( + MemorySpace{}, + Kokkos::View(points.data(), points.size()))}; + + { + Kokkos::View values("values", 0); + Kokkos::View offsets("offsets", 0); + bvh.query(ExecutionSpace{}, FirstOctant{}, PrintfCallback{}, values, + offsets); +#ifndef __NVCC__ + bvh.query(ExecutionSpace{}, FirstOctant{}, + KOKKOS_LAMBDA(auto /*predicate*/, int primitive, + auto /*output_functor*/) { + printf("Found %d from generic lambda\n", primitive); + }, + values, offsets); +#endif + } + + { + int const k = 10; + Kokkos::View values("values", 0); + Kokkos::View offsets("offsets", 0); + bvh.query(ExecutionSpace{}, NearestToOrigin{k}, PrintfCallback{}, values, + offsets); +#ifndef __NVCC__ + bvh.query(ExecutionSpace{}, NearestToOrigin{k}, + KOKKOS_LAMBDA(auto /*predicate*/, int primitive, float distance, + auto /*output_functor*/) { + printf("Found %d with distance %.3f from generic lambda\n", + primitive, distance); + }, + values, offsets); +#endif + } + + { + // EXPERIMENTAL + // TODO replace with BVH::query(ExecutionSpace, Predicates, Callback) when + // new overload is added + Kokkos::View> c( + "counter"); + + ArborX::Details::traverse( + ExecutionSpace{}, bvh, FirstOctant{}, + KOKKOS_LAMBDA(int i, int j) { printf("%d %d %d\n", ++c(), i, j); }); + } + + return 0; +} diff --git a/src/details/ArborX_Callbacks.hpp b/src/details/ArborX_Callbacks.hpp index a52e67fde..f2ba99b11 100644 --- a/src/details/ArborX_Callbacks.hpp +++ b/src/details/ArborX_Callbacks.hpp @@ -79,6 +79,13 @@ using SpatialPredicateInlineCallbackArchetypeExpression = template using CallbackTagArchetypeAlias = typename Callback::tag; +template +struct is_tagged_post_callback + : std::is_same, + PostCallbackTag>::type +{ +}; + // output functor to pass to the callback during detection template struct Sink @@ -93,26 +100,21 @@ template void check_valid_callback(Callback const &, Predicates const &, OutputView const &) { - static_assert(is_detected{}, - "Callback must define 'tag' member type"); - - using CallbackTag = detected_t; - static_assert(std::is_same{} || - std::is_same{}, - "Tag must be either 'InlineCallbackTag' or 'PostCallbackTag'"); +#ifdef __NVCC__ + // Without it would get a segmentation fault and no diagnostic whatsoever + static_assert( + !__nv_is_extended_host_device_lambda_closure_type(Callback), + "__host__ __device__ extended lambdas cannot be generic lambdas"); +#endif using Access = Traits::Access; using PredicateTag = typename Traits::Helper::tag; using Predicate = typename Traits::Helper::type; - // FIXME - constexpr bool short_circuit = std::is_same{}; static_assert( - short_circuit || - (std::is_same{} && - is_detected>{}) || + (std::is_same{} && + is_detected>{}) || (std::is_same{} && is_detected>{}), diff --git a/src/details/ArborX_DetailsBoundingVolumeHierarchyImpl.hpp b/src/details/ArborX_DetailsBoundingVolumeHierarchyImpl.hpp index 0df4856db..a2b66b437 100644 --- a/src/details/ArborX_DetailsBoundingVolumeHierarchyImpl.hpp +++ b/src/details/ArborX_DetailsBoundingVolumeHierarchyImpl.hpp @@ -117,7 +117,8 @@ namespace BoundingVolumeHierarchyImpl // is called. template -std::enable_if_t::value> +std::enable_if_t{} && + Kokkos::is_view{} && Kokkos::is_view{}> queryDispatch(SpatialPredicateTag, BVH const &bvh, ExecutionSpace const &space, Predicates const &predicates, Callback const &callback, OutputView &out, OffsetView &offset, @@ -189,8 +190,7 @@ queryDispatch(SpatialPredicateTag, BVH const &bvh, ExecutionSpace const &space, template -inline std::enable_if_t< - std::is_same::value> +inline std::enable_if_t{}> queryDispatch(SpatialPredicateTag, BVH const &bvh, ExecutionSpace const &space, Predicates const &predicates, Callback const &callback, OutputView &out, OffsetView &offset, @@ -206,7 +206,8 @@ queryDispatch(SpatialPredicateTag, BVH const &bvh, ExecutionSpace const &space, template -std::enable_if_t::value> +std::enable_if_t{} && + Kokkos::is_view{} && Kokkos::is_view{}> queryDispatch(NearestPredicateTag, BVH const &bvh, ExecutionSpace const &space, Predicates const &predicates, Callback const &callback, OutputView &out, OffsetView &offset, @@ -266,8 +267,7 @@ queryDispatch(NearestPredicateTag, BVH const &bvh, ExecutionSpace const &space, template -inline std::enable_if_t< - std::is_same::value> +inline std::enable_if_t{}> queryDispatch(NearestPredicateTag, BVH const &bvh, ExecutionSpace const &space, Predicates const &predicates, Callback const &callback, OutputView &out, OffsetView &offset, @@ -325,7 +325,8 @@ queryDispatch(NearestPredicateTag, BVH const &bvh, ExecutionSpace const &space, } // namespace BoundingVolumeHierarchyImpl template -std::enable_if_t{}> +std::enable_if_t{} && + !is_tagged_post_callback{}> check_valid_callback_if_first_argument_is_not_a_view( Callback const &callback, Predicates const &predicates, OutputView const &out) @@ -333,6 +334,16 @@ check_valid_callback_if_first_argument_is_not_a_view( check_valid_callback(callback, predicates, out); } +template +std::enable_if_t{} && + is_tagged_post_callback{}> +check_valid_callback_if_first_argument_is_not_a_view(Callback const &, + Predicates const &, + OutputView const &) +{ + // TODO +} + template std::enable_if_t{}> check_valid_callback_if_first_argument_is_not_a_view(View const &, diff --git a/test/tstCallbacks.cpp b/test/tstCallbacks.cpp index bbabfa04c..51ca9c7c2 100644 --- a/test/tstCallbacks.cpp +++ b/test/tstCallbacks.cpp @@ -51,6 +51,18 @@ struct NearestPredicateCallbackMissingTag } }; +struct Wrong +{ +}; + +struct SpatialPredicateCallbackDoesNotTakeCorrectArgument +{ + template + void operator()(Wrong, int, OutputFunctor const &) const + { + } +}; + int main() { using ArborX::Details::check_valid_callback; @@ -68,11 +80,32 @@ int main() ArborX::Details::CallbackDefaultNearestPredicateWithDistance{}, NearestPredicates{}, v); + // not required to tag inline callbacks any more + check_valid_callback(SpatialPredicateCallbackMissingTag{}, + SpatialPredicates{}, v); + + check_valid_callback(NearestPredicateCallbackMissingTag{}, + NearestPredicates{}, v); + + // generic lambdas are supported if not using NVCC +#ifndef __NVCC__ + check_valid_callback([](auto const & /*predicate*/, int /*primitive*/, + auto const & /*out*/) {}, + SpatialPredicates{}, v); + + check_valid_callback([](auto const & /*predicate*/, int /*primitive*/, + float /*distance*/, auto const & /*out*/) {}, + NearestPredicates{}, v); +#endif + // Uncomment to see error messages - // check_valid_callback(SpatialPredicateCallbackMissingTag{}, + // check_valid_callback(SpatialPredicateCallbackDoesNotTakeCorrectArgument{}, // SpatialPredicates{}, v); - // check_valid_callback(NearestPredicateCallbackMissingTag{}, + // check_valid_callback(ArborX::Details::CallbackDefaultSpatialPredicate{}, // NearestPredicates{}, v); + + // check_valid_callback(ArborX::Details::CallbackDefaultNearestPredicate{}, + // SpatialPredicates{}, v); }