From 64d8fc1c92ee0707dca4d9fada828b0b89c96e5d Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Sun, 22 Jun 2025 21:18:56 -0700 Subject: [PATCH] [ROCm] Enable more parallelism for multi-dimensional reductions (#2291) cherry-pick of https://github.com/pytorch/pytorch/commit/085f270a00b4452bbb005d6b3d448e9d0b9d6fa0 in rocm/pytorch:release/2.7 Co-authored-by: Doru Bercea, Glen Cao --- aten/src/ATen/native/cuda/Reduce.cuh | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 3d4fa34735b8f..d84df6e55837a 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -1115,13 +1115,19 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ int max_threads_per_mp = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor; #ifdef USE_ROCM - // Control the number of threadblocks by adjusting the maximum number of - // threads per multi-processor. These numbers better reflect the maximum - // theoretical achievable threads per MP for the reduction operation. - if (iter.ndim() == 1 || iter.ndim() == 3) - max_threads_per_mp = 512; - if (iter.ndim() == 2) - max_threads_per_mp = 256; + // If the grid consists of a single threadblock, do not change the max threads per + // MP value. This will increase the parallelism across the y dimension of the grid. + bool uses_a_single_block = config.grid().x == config.grid().y == config.grid().z == 1; + + if (!uses_a_single_block) { + // Control the number of threadblocks by adjusting the maximum number of + // threads per multi-processor. These numbers better reflect the maximum + // theoretical achievable threads per MP for the reduction operation. + if (iter.ndim() == 1 || iter.ndim() == 3) + max_threads_per_mp = 512; + else if (iter.ndim() == 2) + max_threads_per_mp = 256; + } #endif const int blocks_per_sm = max_threads_per_mp / config.num_threads; const int target_grid_size = num_mp * blocks_per_sm;