From 6466e8d7c096c7d9d4e695b485821a54eda5373a Mon Sep 17 00:00:00 2001
From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com>
Date: Thu, 4 Dec 2025 09:38:47 -0800
Subject: [PATCH] roll kernel as grid stride loop

cherry-pick of https://github.com/pytorch/pytorch/pull/169474
---
 .../ATen/native/cuda/TensorTransformations.cu | 32 +++++++++++++++++---------------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/aten/src/ATen/native/cuda/TensorTransformations.cu b/aten/src/ATen/native/cuda/TensorTransformations.cu
index c1c5c4399c051..7d2da04e5fc6f 100644
--- a/aten/src/ATen/native/cuda/TensorTransformations.cu
+++ b/aten/src/ATen/native/cuda/TensorTransformations.cu
@@ -90,20 +90,20 @@ __global__ void roll_cuda_kernel(
     int64_t size,
     int64_t stride,
     int64_t total_dims) {
-  int64_t linear_index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
-  if (linear_index >= N) {
-    return;
-  }
-  // roll dim idx is the index of linear_index along the rolling dimension.
-  int64_t roll_dim_idx = linear_index % (stride * size) / stride;
-  // index into the source data to find appropriate value.
-  int64_t source_idx = 0;
-  if( roll_dim_idx >= (size - start) ) {
-    source_idx = linear_index - ((size - start) * stride);
-  } else {
-    source_idx = linear_index + (start * stride);
+  for (int64_t linear_index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
+       linear_index < N; linear_index += blockDim.x*gridDim.x)
+  {
+    // roll dim idx is the index of linear_index along the rolling dimension.
+    int64_t roll_dim_idx = linear_index % (stride * size) / stride;
+    // index into the source data to find appropriate value.
+    int64_t source_idx = 0;
+    if( roll_dim_idx >= (size - start) ) {
+      source_idx = linear_index - ((size - start) * stride);
+    } else {
+      source_idx = linear_index + (start * stride);
+    }
+    out_tensor[linear_index] = in_tensor[source_idx];
   }
-  out_tensor[linear_index] = in_tensor[source_idx];
 }
 
 // Roll a tensor along a dimension
@@ -129,8 +129,10 @@ Tensor roll_cuda(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
   if( start < 0 ) start = start + size;
 
   dim3 dim_block = cuda::getApplyBlock();
-  dim3 dim_grid;
-  TORCH_CHECK(cuda::getApplyGrid(N, dim_grid, in_tensor.get_device()), "unable to get dim grid");
+
+  const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+  // Given a thread block size of 512, we launch with 4 blocks per SM/CU
+  dim3 dim_grid(num_mp * 4);
 
   auto total_dims = in_tensor.dim();
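
Note (not part of the patch): the first hunk rewrites the kernel body as a grid-stride loop, so a fixed-size grid can cover any number of elements, and the second hunk sizes that grid from the device's multiprocessor count instead of deriving it from N via cuda::getApplyGrid. Below is a minimal standalone sketch of the same grid-stride pattern; the kernel name copy_kernel, the out/in/n parameters, and the 512-thread / 4-blocks-per-SM launch are illustrative assumptions, not code from this patch.

    // Hypothetical sketch: each thread starts at its global linear index and
    // then strides by the total number of threads in the grid, so the loop
    // visits every element even when gridDim.x * blockDim.x < n.
    __global__ void copy_kernel(float* out, const float* in, int64_t n) {
      for (int64_t i = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
           i < n;
           i += (int64_t) blockDim.x * gridDim.x) {
        out[i] = in[i];
      }
    }

    // Hypothetical launch mirroring the patch: 4 blocks per SM, 512 threads each.
    //   int num_mp;
    //   cudaDeviceGetAttribute(&num_mp, cudaDevAttrMultiProcessorCount, 0);
    //   copy_kernel<<<num_mp * 4, 512>>>(out, in, n);

With a modest oversubscription of a few blocks per SM, the loop amortizes per-thread setup over many elements and keeps the launch configuration independent of tensor size, which appears to be the rationale for the num_mp * 4 grid in the patch.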