Skip to content

Commit

Permalink
Merge pull request #325 from dalg24/callback
Browse files Browse the repository at this point in the history
Add example for callbacks and lift requirement for tagging inline
  • Loading branch information
dalg24 committed May 29, 2020
2 parents 5d8f04e + 67f0e60 commit 59cbe5c
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 23 deletions.
2 changes: 2 additions & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ endif()

add_subdirectory(access_traits)

add_subdirectory(callback)

find_package(Boost COMPONENTS program_options)
if(Boost_FOUND)
add_subdirectory(viz)
Expand Down
3 changes: 3 additions & 0 deletions examples/callback/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
add_executable(ArborX_Callback.exe example_callback.cpp)
target_link_libraries(ArborX_Callback.exe ${ArborX_TARGET})
add_test(NAME ArborX_Callback_Example COMMAND ./ArborX_Callback.exe)
149 changes: 149 additions & 0 deletions examples/callback/example_callback.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/****************************************************************************
* Copyright (c) 2012-2020 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
* distributed under a BSD 3-clause license. For the licensing terms see *
* the LICENSE file in the top-level directory. *
* *
* SPDX-License-Identifier: BSD-3-Clause *
****************************************************************************/

#include <ArborX.hpp>

#include <Kokkos_Core.hpp>

#include <iostream>
#include <random>
#include <vector>

using ExecutionSpace = Kokkos::DefaultExecutionSpace;
using MemorySpace = ExecutionSpace::memory_space;

struct FirstOctant
{
};

struct NearestToOrigin
{
int k;
};

namespace ArborX
{
namespace Traits
{
template <>
struct Access<FirstOctant, PredicatesTag>
{
KOKKOS_FUNCTION static std::size_t size(FirstOctant) { return 1; }
KOKKOS_FUNCTION static auto get(FirstOctant, std::size_t)
{
return intersects(Box{{{0, 0, 0}}, {{1, 1, 1}}});
}
using memory_space = MemorySpace;
};
template <>
struct Access<NearestToOrigin, PredicatesTag>
{
KOKKOS_FUNCTION static std::size_t size(NearestToOrigin) { return 1; }
KOKKOS_FUNCTION static auto get(NearestToOrigin d, std::size_t)
{
return nearest(Point{0, 0, 0}, d.k);
}
using memory_space = MemorySpace;
};
} // namespace Traits
} // namespace ArborX

struct PairIndexDistance
{
int index;
float distance;
};

struct PrintfCallback
{
template <typename Predicate, typename OutputFunctor>
KOKKOS_FUNCTION void operator()(Predicate, int primitive,
OutputFunctor const &out) const
{
printf("Found %d from functor\n", primitive);
out(primitive);
}
template <typename Predicate, typename OutputFunctor>
KOKKOS_FUNCTION void operator()(Predicate, int primitive, float distance,
OutputFunctor const &out) const
{
printf("Found %d with distance %.3f from functor\n", primitive, distance);
out({primitive, distance});
}
};

int main(int argc, char *argv[])
{
Kokkos::ScopeGuard guard(argc, argv);

int const n = 100;
std::vector<ArborX::Point> points;
// Fill vector with random points in [-1, 1]^3
std::uniform_real_distribution<float> dis{-1., 1.};
std::default_random_engine gen;
auto rd = [&]() { return dis(gen); };
std::generate_n(std::back_inserter(points), n, [&]() {
return ArborX::Point{rd(), rd(), rd()};
});

ArborX::BVH<MemorySpace> bvh{
ExecutionSpace{},
Kokkos::create_mirror_view_and_copy(
MemorySpace{},
Kokkos::View<ArborX::Point *, Kokkos::HostSpace,
Kokkos::MemoryUnmanaged>(points.data(), points.size()))};

{
Kokkos::View<int *, MemorySpace> values("values", 0);
Kokkos::View<int *, MemorySpace> offsets("offsets", 0);
bvh.query(ExecutionSpace{}, FirstOctant{}, PrintfCallback{}, values,
offsets);
#ifndef __NVCC__
bvh.query(ExecutionSpace{}, FirstOctant{},
KOKKOS_LAMBDA(auto /*predicate*/, int primitive,
auto /*output_functor*/) {
printf("Found %d from generic lambda\n", primitive);
},
values, offsets);
#endif
}

{
int const k = 10;
Kokkos::View<PairIndexDistance *, MemorySpace> values("values", 0);
Kokkos::View<int *, MemorySpace> offsets("offsets", 0);
bvh.query(ExecutionSpace{}, NearestToOrigin{k}, PrintfCallback{}, values,
offsets);
#ifndef __NVCC__
bvh.query(ExecutionSpace{}, NearestToOrigin{k},
KOKKOS_LAMBDA(auto /*predicate*/, int primitive, float distance,
auto /*output_functor*/) {
printf("Found %d with distance %.3f from generic lambda\n",
primitive, distance);
},
values, offsets);
#endif
}

{
// EXPERIMENTAL
// TODO replace with BVH::query(ExecutionSpace, Predicates, Callback) when
// new overload is added
Kokkos::View<int, ExecutionSpace, Kokkos::MemoryTraits<Kokkos::Atomic>> c(
"counter");

ArborX::Details::traverse(
ExecutionSpace{}, bvh, FirstOctant{},
KOKKOS_LAMBDA(int i, int j) { printf("%d %d %d\n", ++c(), i, j); });
}

return 0;
}
30 changes: 16 additions & 14 deletions src/details/ArborX_Callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ using SpatialPredicateInlineCallbackArchetypeExpression =
template <typename Callback>
using CallbackTagArchetypeAlias = typename Callback::tag;

template <typename Callback>
struct is_tagged_post_callback
: std::is_same<detected_t<CallbackTagArchetypeAlias, Callback>,
PostCallbackTag>::type
{
};

// output functor to pass to the callback during detection
template <typename T>
struct Sink
Expand All @@ -93,26 +100,21 @@ template <typename Callback, typename Predicates, typename OutputView>
void check_valid_callback(Callback const &, Predicates const &,
OutputView const &)
{
static_assert(is_detected<CallbackTagArchetypeAlias, Callback>{},
"Callback must define 'tag' member type");

using CallbackTag = detected_t<CallbackTagArchetypeAlias, Callback>;
static_assert(std::is_same<CallbackTag, InlineCallbackTag>{} ||
std::is_same<CallbackTag, PostCallbackTag>{},
"Tag must be either 'InlineCallbackTag' or 'PostCallbackTag'");
#ifdef __NVCC__
// Without it would get a segmentation fault and no diagnostic whatsoever
static_assert(
!__nv_is_extended_host_device_lambda_closure_type(Callback),
"__host__ __device__ extended lambdas cannot be generic lambdas");
#endif

using Access = Traits::Access<Predicates, Traits::PredicatesTag>;
using PredicateTag = typename Traits::Helper<Access>::tag;
using Predicate = typename Traits::Helper<Access>::type;

// FIXME
constexpr bool short_circuit = std::is_same<CallbackTag, PostCallbackTag>{};
static_assert(
short_circuit ||
(std::is_same<PredicateTag, SpatialPredicateTag>{} &&
is_detected<SpatialPredicateInlineCallbackArchetypeExpression,
Callback, Predicate,
OutputFunctorHelper<OutputView>>{}) ||
(std::is_same<PredicateTag, SpatialPredicateTag>{} &&
is_detected<SpatialPredicateInlineCallbackArchetypeExpression, Callback,
Predicate, OutputFunctorHelper<OutputView>>{}) ||
(std::is_same<PredicateTag, NearestPredicateTag>{} &&
is_detected<NearestPredicateInlineCallbackArchetypeExpression,
Callback, Predicate, OutputFunctorHelper<OutputView>>{}),
Expand Down
25 changes: 18 additions & 7 deletions src/details/ArborX_DetailsBoundingVolumeHierarchyImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ namespace BoundingVolumeHierarchyImpl
// is called.
template <typename BVH, typename ExecutionSpace, typename Predicates,
typename OutputView, typename OffsetView, typename Callback>
std::enable_if_t<std::is_same<typename Callback::tag, InlineCallbackTag>::value>
std::enable_if_t<!is_tagged_post_callback<Callback>{} &&
Kokkos::is_view<OutputView>{} && Kokkos::is_view<OffsetView>{}>
queryDispatch(SpatialPredicateTag, BVH const &bvh, ExecutionSpace const &space,
Predicates const &predicates, Callback const &callback,
OutputView &out, OffsetView &offset,
Expand Down Expand Up @@ -189,8 +190,7 @@ queryDispatch(SpatialPredicateTag, BVH const &bvh, ExecutionSpace const &space,

template <typename BVH, typename ExecutionSpace, typename Predicates,
typename OutputView, typename OffsetView, typename Callback>
inline std::enable_if_t<
std::is_same<typename Callback::tag, PostCallbackTag>::value>
inline std::enable_if_t<is_tagged_post_callback<Callback>{}>
queryDispatch(SpatialPredicateTag, BVH const &bvh, ExecutionSpace const &space,
Predicates const &predicates, Callback const &callback,
OutputView &out, OffsetView &offset,
Expand All @@ -206,7 +206,8 @@ queryDispatch(SpatialPredicateTag, BVH const &bvh, ExecutionSpace const &space,

template <typename BVH, typename ExecutionSpace, typename Predicates,
typename OutputView, typename OffsetView, typename Callback>
std::enable_if_t<std::is_same<typename Callback::tag, InlineCallbackTag>::value>
std::enable_if_t<!is_tagged_post_callback<Callback>{} &&
Kokkos::is_view<OutputView>{} && Kokkos::is_view<OffsetView>{}>
queryDispatch(NearestPredicateTag, BVH const &bvh, ExecutionSpace const &space,
Predicates const &predicates, Callback const &callback,
OutputView &out, OffsetView &offset,
Expand Down Expand Up @@ -266,8 +267,7 @@ queryDispatch(NearestPredicateTag, BVH const &bvh, ExecutionSpace const &space,

template <typename BVH, typename ExecutionSpace, typename Predicates,
typename OutputView, typename OffsetView, typename Callback>
inline std::enable_if_t<
std::is_same<typename Callback::tag, PostCallbackTag>::value>
inline std::enable_if_t<is_tagged_post_callback<Callback>{}>
queryDispatch(NearestPredicateTag, BVH const &bvh, ExecutionSpace const &space,
Predicates const &predicates, Callback const &callback,
OutputView &out, OffsetView &offset,
Expand Down Expand Up @@ -325,14 +325,25 @@ queryDispatch(NearestPredicateTag, BVH const &bvh, ExecutionSpace const &space,
} // namespace BoundingVolumeHierarchyImpl

template <typename Callback, typename Predicates, typename OutputView>
std::enable_if_t<!Kokkos::is_view<Callback>{}>
std::enable_if_t<!Kokkos::is_view<Callback>{} &&
!is_tagged_post_callback<Callback>{}>
check_valid_callback_if_first_argument_is_not_a_view(
Callback const &callback, Predicates const &predicates,
OutputView const &out)
{
check_valid_callback(callback, predicates, out);
}

template <typename Callback, typename Predicates, typename OutputView>
std::enable_if_t<!Kokkos::is_view<Callback>{} &&
is_tagged_post_callback<Callback>{}>
check_valid_callback_if_first_argument_is_not_a_view(Callback const &,
Predicates const &,
OutputView const &)
{
// TODO
}

template <typename View, typename Predicates, typename OutputView>
std::enable_if_t<Kokkos::is_view<View>{}>
check_valid_callback_if_first_argument_is_not_a_view(View const &,
Expand Down
37 changes: 35 additions & 2 deletions test/tstCallbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ struct NearestPredicateCallbackMissingTag
}
};

struct Wrong
{
};

struct SpatialPredicateCallbackDoesNotTakeCorrectArgument
{
template <typename OutputFunctor>
void operator()(Wrong, int, OutputFunctor const &) const
{
}
};

int main()
{
using ArborX::Details::check_valid_callback;
Expand All @@ -68,11 +80,32 @@ int main()
ArborX::Details::CallbackDefaultNearestPredicateWithDistance{},
NearestPredicates{}, v);

// not required to tag inline callbacks any more
check_valid_callback(SpatialPredicateCallbackMissingTag{},
SpatialPredicates{}, v);

check_valid_callback(NearestPredicateCallbackMissingTag{},
NearestPredicates{}, v);

// generic lambdas are supported if not using NVCC
#ifndef __NVCC__
check_valid_callback([](auto const & /*predicate*/, int /*primitive*/,
auto const & /*out*/) {},
SpatialPredicates{}, v);

check_valid_callback([](auto const & /*predicate*/, int /*primitive*/,
float /*distance*/, auto const & /*out*/) {},
NearestPredicates{}, v);
#endif

// Uncomment to see error messages

// check_valid_callback(SpatialPredicateCallbackMissingTag{},
// check_valid_callback(SpatialPredicateCallbackDoesNotTakeCorrectArgument{},
// SpatialPredicates{}, v);

// check_valid_callback(NearestPredicateCallbackMissingTag{},
// check_valid_callback(ArborX::Details::CallbackDefaultSpatialPredicate{},
// NearestPredicates{}, v);

// check_valid_callback(ArborX::Details::CallbackDefaultNearestPredicate{},
// SpatialPredicates{}, v);
}

0 comments on commit 59cbe5c

Please sign in to comment.