From 8ac0da3dcdebbb71315a203ace2a5d5bdb48c8a0 Mon Sep 17 00:00:00 2001
From: Glen Cao
Date: Fri, 20 Jun 2025 18:25:41 -0700
Subject: [PATCH] Incorporate Doru Bercea's fix from
 https://github.com/pytorch/pytorch/pull/155806 into the ROCm PyTorch 2.7
 release

---
 aten/src/ATen/native/cuda/Reduce.cuh | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh
index ad2588e181ed9..15a572804af5f 100644
--- a/aten/src/ATen/native/cuda/Reduce.cuh
+++ b/aten/src/ATen/native/cuda/Reduce.cuh
@@ -1118,13 +1118,19 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
   int max_threads_per_mp =
       at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor;
 #ifdef USE_ROCM
-  // Control the number of threadblocks by adjusting the maximum number of
-  // threads per multi-processor. These numbers better reflect the maximum
-  // theoretical achievable threads per MP for the reduction operation.
-  if (iter.ndim() == 1 || iter.ndim() == 3)
-    max_threads_per_mp = 512;
-  if (iter.ndim() == 2)
-    max_threads_per_mp = 256;
+  // If the grid consists of a single threadblock, do not change the max threads per
+  // MP value. This will increase the parallelism across the y dimension of the grid.
+  bool uses_a_single_block = config.grid().x == 1 && config.grid().y == 1 && config.grid().z == 1;
+
+  if (!uses_a_single_block) {
+    // Control the number of threadblocks by adjusting the maximum number of
+    // threads per multi-processor. These numbers better reflect the maximum
+    // theoretical achievable threads per MP for the reduction operation.
+    if (iter.ndim() == 1 || iter.ndim() == 3)
+      max_threads_per_mp = 512;
+    else if (iter.ndim() == 2)
+      max_threads_per_mp = 256;
+  }
 #endif
   const int blocks_per_sm = max_threads_per_mp / config.num_threads;
   const int target_grid_size = num_mp * blocks_per_sm;
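
Reviewer note, not part of the patch: the two context lines at the end of the hunk derive blocks_per_sm and target_grid_size from max_threads_per_mp, so the ROCm clamp shrinks the target grid; the new guard skips the clamp when the launch grid is a single block, leaving the larger target available for splitting the reduction along y (as the added comment states). The standalone C++ sketch below only illustrates that arithmetic; the device properties, block size, and grid values in it are assumptions for the example, not taken from the patch or from any device.

// Illustrative sketch only; all numeric values and the grid below are
// assumed for the example, not taken from the patch.
#include <cstdio>

int main() {
  const int num_mp = 64;                   // assumed number of CUs on the GPU
  const int num_threads = 256;             // assumed threads per block chosen by the config
  const int hw_max_threads_per_mp = 1024;  // assumed hardware maximum

  // Hypothetical launch grid picked earlier by the reduce config.
  const int grid_x = 1, grid_y = 1, grid_z = 1;
  const bool uses_a_single_block = grid_x == 1 && grid_y == 1 && grid_z == 1;

  // Mirror of the patched ROCm branch for a 2-D iterator: only clamp when
  // the grid has more than one block.
  int max_threads_per_mp = hw_max_threads_per_mp;
  if (!uses_a_single_block) {
    max_threads_per_mp = 256;
  }

  // Same arithmetic as the two context lines after the hunk.
  const int blocks_per_sm = max_threads_per_mp / num_threads;
  const int target_grid_size = num_mp * blocks_per_sm;

  // Single-block grid: 1024 / 256 = 4 blocks per SM, target grid of 256
  // blocks; with the clamp it would have been 64.
  std::printf("blocks_per_sm=%d target_grid_size=%d\n",
              blocks_per_sm, target_grid_size);
  return 0;
}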