// caffe2/operators/find_op.cu (forked from pytorch/pytorch)
#include <cub/block/block_reduce.cuh>
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/find_op.h"
namespace caffe2 {
// Searches `idx` for each element of `needles`, one CUDA block per needle.
// Writes to out[needle_idx] the index of the LAST occurrence of the needle
// in idx[0..idx_size), or `missing_value` if the needle is absent.
//
// Launch contract: gridDim.x >= num_needles and blockDim.x ==
// CAFFE_CUDA_NUM_THREADS — the loop stride and the BlockReduce
// specialization below both hard-code that thread count.
template <typename T>
__global__ void FindKernel(
    int num_needles,
    int idx_size,
    const T* idx,
    const T* needles,
    int* out,
    int missing_value) {
  int needle_idx = blockIdx.x; // One cuda block per needle
  // Block-uniform guard: every thread in the block sees the same
  // needle_idx, so returning here cannot desynchronize BlockReduce.
  if (needle_idx >= num_needles) {
    return;
  }
  T q = needles[needle_idx];
  // Each thread scans a strided slice of idx, keeping the largest
  // matching position it sees (ties resolve to the last occurrence).
  int res = (-1);
  for (int j = threadIdx.x; j < idx_size; j += CAFFE_CUDA_NUM_THREADS) {
    if (idx[j] == q) {
      res = max(res, j);
    }
  }
  typedef cub::BlockReduce<int, CAFFE_CUDA_NUM_THREADS> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  // Block-wide max of the per-thread candidates: index of the last
  // occurrence overall, or -1 when no thread found a match.
  int found = BlockReduce(temp_storage).Reduce(res, cub::Max());
  if (threadIdx.x == 0) {
    out[needle_idx] = found == (-1) ? missing_value : found;
  }
}
// CUDA implementation of the Find operator: for every element of Input(1)
// (the needles), emit the index of its last occurrence in Input(0) (the
// haystack), or missing_value_ if it does not occur. Output has the same
// shape as the needles tensor, dtype int32.
template <>
template <typename T>
bool FindOp<CUDAContext>::DoRunWithType() {
  auto& idx = Input(0);
  auto& needles = Input(1);
  auto* res_indices = Output(0, needles.sizes(), at::dtype<int>());
  // A zero-block kernel launch is invalid (cudaErrorInvalidConfiguration),
  // so return early when there are no needles; the empty output tensor has
  // already been resized above.
  if (needles.numel() == 0) {
    return true;
  }
  const T* idx_data = idx.data<T>();
  const T* needles_data = needles.data<T>();
  int* res_data = res_indices->template mutable_data<int>();
  // One block per needle; CAFFE_CUDA_NUM_THREADS threads cooperatively scan
  // `idx` and block-reduce to the last matching position.
  // NOTE(review): numel() is int64_t but the kernel takes int — assumes
  // both tensors have < 2^31 elements; verify against callers.
  FindKernel<
      T><<<needles.numel(), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
      needles.numel(),
      idx.numel(),
      idx_data,
      needles_data,
      res_data,
      missing_value_);
  // Surface launch-configuration errors immediately instead of at the next
  // unrelated synchronizing call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return true;
}
// Register this CUDA implementation under the "Find" operator name.
REGISTER_CUDA_OPERATOR(Find, FindOp<CUDAContext>)
} // namespace caffe2