@@ -90,20 +90,20 @@ __global__ void roll_cuda_kernel(
     int64_t size,
     int64_t stride,
     int64_t total_dims) {
-  int64_t linear_index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
-  if (linear_index >= N) {
-    return;
-  }
-  // roll dim idx is the index of linear_index along the rolling dimension.
-  int64_t roll_dim_idx = linear_index % (stride * size) / stride;
-  // index into the source data to find appropriate value.
-  int64_t source_idx = 0;
-  if (roll_dim_idx >= (size - start)) {
-    source_idx = linear_index - ((size - start) * stride);
-  } else {
-    source_idx = linear_index + (start * stride);
+  for (int64_t linear_index = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
+       linear_index < N; linear_index += blockDim.x * gridDim.x)
+  {
+    // roll dim idx is the index of linear_index along the rolling dimension.
+    int64_t roll_dim_idx = linear_index % (stride * size) / stride;
+    // index into the source data to find appropriate value.
+    int64_t source_idx = 0;
+    if (roll_dim_idx >= (size - start)) {
+      source_idx = linear_index - ((size - start) * stride);
+    } else {
+      source_idx = linear_index + (start * stride);
+    }
+    out_tensor[linear_index] = in_tensor[source_idx];
   }
-  out_tensor[linear_index] = in_tensor[source_idx];
 }
 
 // Roll a tensor along a dimension
@@ -129,8 +129,10 @@ Tensor roll_cuda(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
   if (start < 0) start = start + size;
 
   dim3 dim_block = cuda::getApplyBlock();
-  dim3 dim_grid;
-  TORCH_CHECK(cuda::getApplyGrid(N, dim_grid, in_tensor.get_device()), "unable to get dim grid");
+
+  const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+  // Given a thread block size of 512, we launch with 4 blocks per SM/CU
+  dim3 dim_grid(num_mp * 4);
 
   auto total_dims = in_tensor.dim();
 
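For reference, below is a minimal standalone sketch of the grid-stride-loop pattern this patch adopts, paired with the same launch heuristic of a 512-thread block and 4 blocks per multiprocessor. The kernel name copy_kernel, the buffers, and the managed-memory setup are illustrative assumptions, not part of the PyTorch change.

#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Hypothetical elementwise kernel using the same grid-stride loop shape as the
// patched roll_cuda_kernel: each thread starts at its global index and strides
// by the total number of launched threads, so any grid size covers all N.
__global__ void copy_kernel(const float* in, float* out, int64_t N) {
  for (int64_t i = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
       i < N;
       i += (int64_t) blockDim.x * gridDim.x) {
    out[i] = in[i];
  }
}

int main() {
  const int64_t N = 1 << 20;
  float *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, N * sizeof(float));
  cudaMallocManaged(&out, N * sizeof(float));
  for (int64_t i = 0; i < N; ++i) in[i] = (float) i;

  // Same launch heuristic as the patch: a fixed 512-thread block and a grid of
  // 4 blocks per multiprocessor, independent of N.
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  const int block_size = 512;
  const int grid_size = prop.multiProcessorCount * 4;

  copy_kernel<<<grid_size, block_size>>>(in, out, N);
  cudaDeviceSynchronize();

  printf("out[12345] = %f\n", out[12345]);  // expect 12345.0
  cudaFree(in);
  cudaFree(out);
  return 0;
}

The design point is the same in the sketch and in the patch: because each thread loops with a stride of blockDim.x * gridDim.x, a fixed, occupancy-sized grid covers any N, so the launch no longer needs cuda::getApplyGrid to derive a grid from the element count.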