Skip to content
Permalink
Browse files

Added output iterator adapter for blaze/thrust interoperability

  • Loading branch information...
JPenuchot committed Oct 30, 2019
1 parent 16380e1 commit 8de4bc0e08145178e672ea59a3edd18924ebe64b
Showing with 56 additions and 11 deletions.
  1. +56 −11 blaze_cuda/util/algorithms/CUDATransform.h
@@ -51,7 +51,7 @@ namespace blaze {
namespace detail {

template< typename IteratorType >
class ThrustIteratorAdapter
class ThrustInputIteratorAdapter
{
IteratorType it;

@@ -72,24 +72,67 @@ class ThrustIteratorAdapter
using difference_type = DifferenceType; //!< Difference between two iterators.

BLAZE_ALWAYS_INLINE BLAZE_DEVICE_CALLABLE
ThrustIteratorAdapter( IteratorType const& it ): it(it) {}
ThrustInputIteratorAdapter( IteratorType const& it ): it(it) {}

BLAZE_ALWAYS_INLINE BLAZE_DEVICE_CALLABLE ValueType operator[]( ptrdiff_t inc ) const noexcept {
return *ThrustIteratorAdapter( it + inc );
return *ThrustInputIteratorAdapter( it + inc );
}

BLAZE_ALWAYS_INLINE BLAZE_DEVICE_CALLABLE ValueType operator*() const noexcept {
return *it;
}

BLAZE_ALWAYS_INLINE BLAZE_DEVICE_CALLABLE auto
operator-( ThrustIteratorAdapter const& other ) const noexcept
operator-( ThrustInputIteratorAdapter const& other ) const noexcept
{
return it - other.it;
return ThrustInputIteratorAdapter( it - other.it );
}

BLAZE_ALWAYS_INLINE BLAZE_DEVICE_CALLABLE auto operator+( ptrdiff_t inc ) const noexcept {
return IteratorType( it + inc );
return ThrustInputIteratorAdapter( IteratorType( it + inc ) );
}
};

template< typename IteratorType >
class ThrustOutputIteratorAdapter
{
IteratorType it;

public:

//**Type definitions****************************************************************************
using IteratorCategory = typename IteratorType::IteratorCategory; //!< The iterator category.
using ValueType = typename IteratorType::ValueType; //!< Type of the underlying elements.
using PointerType = typename IteratorType::PointerType; //!< Pointer return type.
using ReferenceType = typename IteratorType::ReferenceType; //!< Reference return type.
using DifferenceType = typename IteratorType::DifferenceType; //!< Difference between two iterators.

// STL iterator requirements
using iterator_category = IteratorCategory; //!< The iterator category.
using value_type = ValueType; //!< Type of the underlying elements.
using pointer = PointerType; //!< Pointer return type.
using reference = ReferenceType; //!< Reference return type.
using difference_type = DifferenceType; //!< Difference between two iterators.

BLAZE_ALWAYS_INLINE BLAZE_DEVICE_CALLABLE ThrustOutputIteratorAdapter( IteratorType const& it )
: it(it) {}

BLAZE_ALWAYS_INLINE BLAZE_DEVICE_CALLABLE ReferenceType operator[]( ptrdiff_t inc ) const noexcept {
return *ThrustOutputIteratorAdapter( it + inc );
}

BLAZE_ALWAYS_INLINE BLAZE_DEVICE_CALLABLE ReferenceType operator*() const noexcept {
return *it;
}

BLAZE_ALWAYS_INLINE BLAZE_DEVICE_CALLABLE auto
operator-( ThrustOutputIteratorAdapter const& other ) noexcept
{
return ThrustOutputIteratorAdapter( it - other.it );
}

BLAZE_ALWAYS_INLINE BLAZE_DEVICE_CALLABLE auto operator+( ptrdiff_t inc ) noexcept {
return ThrustOutputIteratorAdapter( IteratorType( it + inc ) );
}
};

@@ -104,12 +147,14 @@ inline void cuda_transform ( InputIt1 in1_begin , InputIt1 in1_end
, F f )
{
using namespace detail;
using AI2 = ThrustIteratorAdapter<InputIt2>;
using AI1 = ThrustInputIteratorAdapter<InputIt1>;
using AI2 = ThrustInputIteratorAdapter<InputIt2>;
using AO = ThrustOutputIteratorAdapter<OutputIt>;

thrust::transform( thrust::device,
in1_begin, in1_end, // Meant to be the left-hand side
AI2( in2_begin ), // Adaptor for the right-hand side
out_begin, f );
AI1( in1_begin ), AI1( in1_end ), // Meant to be the left-hand side
AI2( in2_begin ), // Adaptor for the right-hand side
AO( out_begin ), f );
}

template < std::size_t Unroll = 16
@@ -120,7 +165,7 @@ inline void cuda_transform ( InputIt1 in1_begin , InputIt1 in1_end
, F f )
{
using namespace detail;
using AI1 = ThrustIteratorAdapter<InputIt1>;
using AI1 = ThrustInputIteratorAdapter<InputIt1>;

thrust::transform( thrust::device, AI1(in1_begin), AI1(in1_end), out_begin, f );
}

0 comments on commit 8de4bc0

Please sign in to comment.
You can’t perform that action at this time.