Commit a47ec2b

roll kernel as grid stride loop
Cherry-pick of #2852

Co-authored-by: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com>
1 parent bd60f20 commit a47ec2b
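
For context on the technique: a grid-stride loop lets a fixed-size grid cover any element count N, because each thread starts at its global index and then advances by the total number of threads in the grid, rather than handling exactly one element and returning early. A minimal sketch of the idiom, with a hypothetical copy_kernel that is not part of this patch:

__global__ void copy_kernel(const float* in, float* out, int64_t N) {
  // Total threads launched across the grid; widen to 64 bits before
  // multiplying so large N does not overflow.
  const int64_t grid_size = (int64_t)blockDim.x * gridDim.x;
  for (int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
       i < N; i += grid_size) {
    out[i] = in[i];  // per-element work goes here
  }
}

With this shape, correctness no longer depends on launching ceil(N / block_size) blocks; any grid size works, which is what allows the launch-configuration change in the second hunk below.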


aten/src/ATen/native/cuda/TensorTransformations.cu

Lines changed: 17 additions & 15 deletions
@@ -90,20 +90,20 @@ __global__ void roll_cuda_kernel(
     int64_t size,
     int64_t stride,
     int64_t total_dims) {
-  int64_t linear_index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
-  if (linear_index >= N) {
-    return;
-  }
-  // roll dim idx is the index of linear_index along the rolling dimension.
-  int64_t roll_dim_idx = linear_index % (stride * size) / stride;
-  // index into the source data to find appropriate value.
-  int64_t source_idx = 0;
-  if( roll_dim_idx >= (size - start) ) {
-    source_idx = linear_index - ((size - start) * stride);
-  } else {
-    source_idx = linear_index + (start * stride);
+  for (int64_t linear_index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
+       linear_index < N; linear_index += blockDim.x*gridDim.x)
+  {
+    // roll dim idx is the index of linear_index along the rolling dimension.
+    int64_t roll_dim_idx = linear_index % (stride * size) / stride;
+    // index into the source data to find appropriate value.
+    int64_t source_idx = 0;
+    if( roll_dim_idx >= (size - start) ) {
+      source_idx = linear_index - ((size - start) * stride);
+    } else {
+      source_idx = linear_index + (start * stride);
+    }
+    out_tensor[linear_index] = in_tensor[source_idx];
   }
-  out_tensor[linear_index] = in_tensor[source_idx];
 }
 
 // Roll a tensor along a dimension
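
As a concrete check of the index arithmetic (which this patch moves but does not change): take a row-major (4, 3) tensor rolled by 1 along dim 0, so the rolling dimension has stride = 3 and size = 4. Assuming the host passes start = 3 — what (size - shift) % size yields for shift 1; that computation is outside this hunk — the loop below replays the kernel's mapping on the CPU as a hypothetical walkthrough:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t size = 4, stride = 3, start = 3, N = 12;
  for (int64_t linear_index = 0; linear_index < N; ++linear_index) {
    // Position of this element along the rolling dimension.
    int64_t roll_dim_idx = linear_index % (stride * size) / stride;
    int64_t source_idx = (roll_dim_idx >= size - start)
        ? linear_index - (size - start) * stride
        : linear_index + start * stride;
    std::printf("out[%2lld] = in[%2lld]\n",
                (long long)linear_index, (long long)source_idx);
  }
  return 0;
}

The first three lines print out[0..2] = in[9..11]: output row 0 is input row 3, which matches rolling the rows forward by one.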
@@ -129,8 +129,10 @@ Tensor roll_cuda(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
   if( start < 0 ) start = start + size;
 
   dim3 dim_block = cuda::getApplyBlock();
-  dim3 dim_grid;
-  TORCH_CHECK(cuda::getApplyGrid(N, dim_grid, in_tensor.get_device()), "unable to get dim grid");
+
+  const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+  // Given a thread block size of 512, we launch with 4 blocks per SM/CU
+  dim3 dim_grid(num_mp * 4);
 
   auto total_dims = in_tensor.dim();
 
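The replaced TORCH_CHECK/getApplyGrid path sized the grid from N; the new code sizes it from the device's SM count, which the grid-stride kernel makes safe. A standalone sketch of the same sizing against the plain CUDA runtime API — grid_for_device is a hypothetical helper, and the 512-thread block size is taken from the patch comment:

#include <cuda_runtime.h>

// 4 resident blocks per SM/CU, independent of the element count N.
// ATen's getCurrentDeviceProperties() exposes the same cudaDeviceProp data.
dim3 grid_for_device(int device) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  return dim3(prop.multiProcessorCount * 4);
}

Oversubscribing each SM by a few blocks keeps the device busy and helps hide memory latency, while avoiding launching far more blocks than the device can run concurrently, as a ceil(N/512) sizing would for large tensors.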