// strided_reduction.cu
// C++ standard library
#include <cassert>
#include <cstdio>   // std::printf
#include <ostream>

// Thrust
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <thrust/functional.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>
#define STRIDE 2
// A view over every `stride`-th element of the range [first, last).
//
// begin() refers to *first, begin() + 1 to *(first + stride), and so on.
// The view is built by composing three Thrust iterators:
//   counting_iterator    : 0, 1, 2, ...
//   transform_iterator   : 0, stride, 2*stride, ...   (via stride_functor)
//   permutation_iterator : first[0], first[stride], first[2*stride], ...
//
// Precondition: stride > 0 (end() divides by stride).
template <typename Iterator>
class strided_range
{
public:
    typedef typename thrust::iterator_difference<Iterator>::type difference_type;

    // Maps a logical index i of the strided view to the offset stride * i
    // in the underlying range.
    struct stride_functor
    {
        // Declared explicitly instead of inheriting from
        // thrust::unary_function, which is deprecated in newer Thrust
        // releases; these are the same nested typedefs that base provided.
        typedef difference_type argument_type;
        typedef difference_type result_type;

        difference_type stride;

        stride_functor(difference_type stride)
            : stride(stride) {}

        __host__ __device__
        difference_type operator()(const difference_type& i) const
        {
            return stride * i;
        }
    };

    typedef typename thrust::counting_iterator<difference_type> CountingIterator;
    typedef typename thrust::transform_iterator<stride_functor, CountingIterator> TransformIterator;
    typedef typename thrust::permutation_iterator<Iterator, TransformIterator> PermutationIterator;

    // type of the strided_range iterator
    typedef PermutationIterator iterator;

    // Construct a strided_range for the range [first, last).
    // `stride` is the distance between consecutive visited elements and
    // must be positive.
    strided_range(Iterator first, Iterator last, difference_type stride)
        : first(first), last(last), stride(stride)
    {
        // end() computes a length by dividing by stride: zero would divide
        // by zero and a negative value would yield a meaningless length.
        assert(stride > 0);
    }

    // Iterator to the first element of the strided view (*first).
    iterator begin(void) const
    {
        return PermutationIterator(first, TransformIterator(CountingIterator(0), stride_functor(stride)));
    }

    // Iterator one past the last element of the strided view.
    iterator end(void) const
    {
        // Number of visited elements is ceil((last - first) / stride).
        return begin() + ((last - first) + (stride - 1)) / stride;
    }

protected:
    Iterator first;          // start of the underlying range
    Iterator last;           // end of the underlying range (exclusive)
    difference_type stride;  // step between visited elements
};
// Sums every STRIDE-th element of a small device vector and prints the result.
//
// The input holds 0, 1, ..., N-1 and the strided view starts at element 0.
int main(void)
{
    const int N = 9;

    // Fill the input on the host, then copy it into a device vector.
    thrust::host_vector<int> h_data(N);
    for (int i = 0; i < N; ++i)
    {
        h_data[i] = i;
    }
    thrust::device_vector<int> data(h_data);

    // View over data[0], data[STRIDE], data[2*STRIDE], ...
    typedef thrust::device_vector<int>::iterator Iterator;
    strided_range<Iterator> pos(data.begin(), data.end(), STRIDE);

    // Reduce only the elements visited by the strided view; the int initial
    // value also fixes the accumulator/result type to int.
    const int sum = thrust::reduce(pos.begin(), pos.end(), 0, thrust::plus<int>());

    std::printf("sum = %i\n", sum);
    return 0;
}