Skip to content

Commit

Permalink
Try harder to unwrap nested thrust::tuple_of_iterator_references
Browse files Browse the repository at this point in the history
We tried to simply unpack the `tuple_of_iterator_references`, however, if it contained nested `tuple_of_iterator_references` then that would break down. Instead recursively apply the unwrapping when possible
  • Loading branch information
miscco committed Mar 1, 2024
1 parent 2acbea2 commit ec9840c
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 101 deletions.
102 changes: 86 additions & 16 deletions thrust/testing/zip_function.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,29 @@

#if _CCCL_STD_VER >= 2011 && !defined(THRUST_LEGACY_GCC)

#include <unittest/unittest.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/transform.h>
#include <thrust/zip_function.h>
# include <thrust/device_vector.h>
# include <thrust/iterator/zip_iterator.h>
# include <thrust/remove.h>
# include <thrust/transform.h>
# include <thrust/zip_function.h>

#include <iostream>
# include <iostream>

# include <unittest/unittest.h>

using namespace unittest;

struct SumThree
{
template <typename T1, typename T2, typename T3>
__host__ __device__
auto operator()(T1 x, T2 y, T3 z) const
THRUST_DECLTYPE_RETURNS(x + y + z)
__host__ __device__ auto operator()(T1 x, T2 y, T3 z) const THRUST_DECLTYPE_RETURNS(x + y + z)
}; // end SumThree

struct SumThreeTuple
{
template <typename Tuple>
__host__ __device__
auto operator()(Tuple x) const
THRUST_DECLTYPE_RETURNS(thrust::get<0>(x) + thrust::get<1>(x) + thrust::get<2>(x))
__host__ __device__ auto operator()(Tuple x) const
THRUST_DECLTYPE_RETURNS(thrust::get<0>(x) + thrust::get<1>(x) + thrust::get<2>(x))
}; // end SumThreeTuple

template <typename T>
Expand All @@ -42,22 +42,22 @@ struct TestZipFunctionTransform
device_vector<T> d_data1 = h_data1;
device_vector<T> d_data2 = h_data2;

host_vector<T> h_result_tuple(n);
host_vector<T> h_result_zip(n);
host_vector<T> h_result_tuple(n);
host_vector<T> h_result_zip(n);
device_vector<T> d_result_zip(n);

// Tuple base case
transform(make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin(), h_data2.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
h_result_tuple.begin(),
SumThreeTuple{});
// Zip Function
transform(make_zip_iterator(make_tuple(h_data0.begin(), h_data1.begin(), h_data2.begin())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
make_zip_iterator(make_tuple(h_data0.end(), h_data1.end(), h_data2.end())),
h_result_zip.begin(),
make_zip_function(SumThree{}));
transform(make_zip_iterator(make_tuple(d_data0.begin(), d_data1.begin(), d_data2.begin())),
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end(), d_data2.end())),
make_zip_iterator(make_tuple(d_data0.end(), d_data1.end(), d_data2.end())),
d_result_zip.begin(),
make_zip_function(SumThree{}));

Expand All @@ -67,4 +67,74 @@ struct TestZipFunctionTransform
};
VariableUnitTest<TestZipFunctionTransform, ThirtyTwoBitTypes> TestZipFunctionTransformInstance;

struct RemovePred
{
__host__ __device__ bool operator()(const thrust::tuple<uint32_t, uint32_t>& ele1, const float&)
{
return thrust::get<0>(ele1) == thrust::get<1>(ele1);
}
};
template <typename T>
struct TestZipFunctionMixed
{
void operator()()
{
thrust::device_vector<uint32_t> vecA{0, 0, 2, 0};
thrust::device_vector<uint32_t> vecB{0, 2, 2, 2};
thrust::device_vector<float> vecC{88.0f, 88.0f, 89.0f, 89.0f};
thrust::device_vector<float> expected{88.0f, 89.0f};

auto inputKeyItBegin =
thrust::make_zip_iterator(thrust::make_zip_iterator(vecA.begin(), vecB.begin()), vecC.begin());
auto endIt =
thrust::remove_if(inputKeyItBegin, inputKeyItBegin + vecA.size(), thrust::make_zip_function(RemovePred{}));
auto numEle = endIt - inputKeyItBegin;
vecA.resize(numEle);
vecB.resize(numEle);
vecC.resize(numEle);

ASSERT_EQUAL(numEle, 2);
ASSERT_EQUAL(vecC, expected);
}
};
SimpleUnitTest<TestZipFunctionMixed, type_list<int, float> > TestZipFunctionMixedInstance;

struct NestedFunctionCall
{
__host__ __device__ bool
operator()(const thrust::tuple<uint32_t, thrust::tuple<thrust::tuple<int, int>, thrust::tuple<int, int>>>& idAndPt)
{
thrust::tuple<thrust::tuple<int, int>, thrust::tuple<int, int>> ele1 = thrust::get<1>(idAndPt);
thrust::tuple<int, int> p1 = thrust::get<0>(ele1);
thrust::tuple<int, int> p2 = thrust::get<1>(ele1);
return thrust::get<0>(p1) == thrust::get<0>(p2) || thrust::get<1>(p1) == thrust::get<1>(p2);
}
};

template <typename T>
struct TestNestedZipFunction
{
void operator()()
{
thrust::device_vector<int> PX{0, 1, 2, 3};
thrust::device_vector<int> PY{0, 1, 2, 2};
thrust::device_vector<uint32_t> SS{0, 1, 2};
thrust::device_vector<uint32_t> ST{1, 2, 3};
thrust::device_vector<float> vecC{88.0f, 88.0f, 89.0f, 89.0f};

auto segIt = thrust::make_zip_iterator(
thrust::make_zip_iterator(thrust::make_permutation_iterator(PX.begin(), SS.begin()),
thrust::make_permutation_iterator(PY.begin(), SS.begin())),
thrust::make_zip_iterator(thrust::make_permutation_iterator(PX.begin(), ST.begin()),
thrust::make_permutation_iterator(PY.begin(), ST.begin())));
auto idAndSegIt = thrust::make_zip_iterator(thrust::make_counting_iterator(0u), segIt);

thrust::device_vector<bool> isMH{false, false, false};
thrust::device_vector<bool> expected{false, false, true};
thrust::transform(idAndSegIt, idAndSegIt + SS.size(), isMH.begin(), NestedFunctionCall{});
ASSERT_EQUAL(isMH, expected);
}
};
SimpleUnitTest<TestNestedZipFunction, type_list<int, float> > TestNestedZipFunctionInstance;

#endif // _CCCL_STD_VER
178 changes: 93 additions & 85 deletions thrust/thrust/iterator/detail/tuple_of_iterator_references.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,106 +26,113 @@
# pragma system_header
#endif // no system header

#include <cuda/std/type_traits>
#include <cuda/std/tuple>

#include <thrust/tuple.h>
#include <thrust/pair.h>
#include <thrust/detail/reference_forward_declaration.h>
#include <thrust/detail/raw_reference_cast.h>
#include <thrust/detail/reference_forward_declaration.h>
#include <thrust/pair.h>
#include <thrust/tuple.h>

#include <cuda/std/tuple>
#include <cuda/std/type_traits>

THRUST_NAMESPACE_BEGIN

namespace detail
{

template<
typename... Ts
>
class tuple_of_iterator_references : public thrust::tuple<Ts...>
template <typename... Ts >
class tuple_of_iterator_references;

template<class U, class T>
struct maybe_unwrap_nested {
_CCCL_HOST_DEVICE U operator()(const T& t) const {
return t;
}
};

template<class... Us, class... Ts>
struct maybe_unwrap_nested<thrust::tuple<Us...>, tuple_of_iterator_references<Ts...>> {
_CCCL_HOST_DEVICE thrust::tuple<Us...> operator()(const tuple_of_iterator_references<Ts...>& t) const {
return t.template __to_tuple<Us...>(typename ::cuda::std::__make_tuple_indices<sizeof...(Ts)>::type{});
}
};

template < typename... Ts >
class tuple_of_iterator_references : public thrust::tuple<Ts...>
{
public:
using super_t = thrust::tuple<Ts...>;
using super_t::super_t;
public:
using super_t = thrust::tuple<Ts...>;
using super_t::super_t;

inline _CCCL_HOST_DEVICE
tuple_of_iterator_references()
inline _CCCL_HOST_DEVICE tuple_of_iterator_references()
: super_t()
{}
{}

// allow implicit construction from tuple<refs>
inline _CCCL_HOST_DEVICE
tuple_of_iterator_references(const super_t& other)
// allow implicit construction from tuple<refs>
inline _CCCL_HOST_DEVICE tuple_of_iterator_references(const super_t& other)
: super_t(other)
{}
{}

inline _CCCL_HOST_DEVICE
tuple_of_iterator_references(super_t&& other)
inline _CCCL_HOST_DEVICE tuple_of_iterator_references(super_t&& other)
: super_t(::cuda::std::move(other))
{}

// allow assignment from tuples
// XXX might be worthwhile to guard this with an enable_if is_assignable
_CCCL_EXEC_CHECK_DISABLE
template<typename... Us>
inline _CCCL_HOST_DEVICE
tuple_of_iterator_references &operator=(const thrust::tuple<Us...> &other)
{
super_t::operator=(other);
return *this;
}

// allow assignment from pairs
// XXX might be worthwhile to guard this with an enable_if is_assignable
_CCCL_EXEC_CHECK_DISABLE
template<typename U1, typename U2>
inline _CCCL_HOST_DEVICE
tuple_of_iterator_references &operator=(const thrust::pair<U1,U2> &other)
{
super_t::operator=(other);
return *this;
}

// allow assignment from reference<tuple>
// XXX perhaps we should generalize to reference<T>
// we could captures reference<pair> this way
_CCCL_EXEC_CHECK_DISABLE
template<typename Pointer, typename Derived, typename... Us>
inline _CCCL_HOST_DEVICE
tuple_of_iterator_references&
operator=(const thrust::reference<thrust::tuple<Us...>, Pointer, Derived> &other)
{
typedef thrust::tuple<Us...> tuple_type;

// XXX perhaps this could be accelerated
super_t::operator=(tuple_type{other});
return *this;
}

template<class... Us, ::cuda::std::__enable_if_t<sizeof...(Us) == sizeof...(Ts), int> = 0>
inline _CCCL_HOST_DEVICE
constexpr operator thrust::tuple<Us...>() const {
return to_tuple<Us...>(typename ::cuda::std::__make_tuple_indices<sizeof...(Ts)>::type{});
}

// this overload of swap() permits swapping tuple_of_iterator_references returned as temporaries from
// iterator dereferences
template<class... Us>
inline _CCCL_HOST_DEVICE
friend void swap(tuple_of_iterator_references&& x, tuple_of_iterator_references<Us...>&& y)
{
x.swap(y);
}

private:
template<class... Us, size_t... Id>
inline _CCCL_HOST_DEVICE
constexpr thrust::tuple<Us...> to_tuple(::cuda::std::__tuple_indices<Id...>) const {
return {get<Id>(*this)...};
}
{}

// allow assignment from tuples
// XXX might be worthwhile to guard this with an enable_if is_assignable
_CCCL_EXEC_CHECK_DISABLE
template <typename... Us>
inline _CCCL_HOST_DEVICE tuple_of_iterator_references& operator=(const thrust::tuple<Us...>& other)
{
super_t::operator=(other);
return *this;
}

// allow assignment from pairs
// XXX might be worthwhile to guard this with an enable_if is_assignable
_CCCL_EXEC_CHECK_DISABLE
template <typename U1, typename U2>
inline _CCCL_HOST_DEVICE tuple_of_iterator_references& operator=(const thrust::pair<U1, U2>& other)
{
super_t::operator=(other);
return *this;
}

// allow assignment from reference<tuple>
// XXX perhaps we should generalize to reference<T>
// we could captures reference<pair> this way
_CCCL_EXEC_CHECK_DISABLE
template <typename Pointer, typename Derived, typename... Us>
inline _CCCL_HOST_DEVICE tuple_of_iterator_references&
operator=(const thrust::reference<thrust::tuple<Us...>, Pointer, Derived>& other)
{
typedef thrust::tuple<Us...> tuple_type;

// XXX perhaps this could be accelerated
super_t::operator=(tuple_type{other});
return *this;
}

template <class... Us, ::cuda::std::__enable_if_t<sizeof...(Us) == sizeof...(Ts), int> = 0>
inline _CCCL_HOST_DEVICE constexpr operator thrust::tuple<Us...>() const
{
return __to_tuple<Us...>(typename ::cuda::std::__make_tuple_indices<sizeof...(Ts)>::type{});
}

// this overload of swap() permits swapping tuple_of_iterator_references returned as temporaries from
// iterator dereferences
template <class... Us>
inline _CCCL_HOST_DEVICE friend void swap(tuple_of_iterator_references&& x, tuple_of_iterator_references<Us...>&& y)
{
x.swap(y);
}

template <class... Us, size_t... Id>
inline _CCCL_HOST_DEVICE constexpr thrust::tuple<Us...> __to_tuple(::cuda::std::__tuple_indices<Id...>) const
{
return {maybe_unwrap_nested<Us, Ts>{}(get<Id>(*this))...};
}
};

} // end detail
} // namespace detail

THRUST_NAMESPACE_END

Expand All @@ -145,7 +152,8 @@ struct tuple_element<Id, THRUST_NS_QUALIFIER::detail::tuple_of_iterator_referenc
_LIBCUDACXX_END_NAMESPACE_STD

// structured bindings suppport
namespace std {
namespace std
{

template <class... Ts>
struct tuple_size<THRUST_NS_QUALIFIER::detail::tuple_of_iterator_references<Ts...>>
Expand Down

0 comments on commit ec9840c

Please sign in to comment.