From 33deb83483e7827f75996182d80117661610a417 Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Wed, 6 Aug 2025 11:13:40 -0700 Subject: [PATCH] [ROCm] Limit number of values per thread for reductions on three dimensions (#2460) In the current implementation of reductions in three dimensions for AMD GPUs the number of values per thread is unbounded and can end up being in the hundreds of thousands for certain tensors. This of course is bad for performance. This patch fixes this issue by increasing the parallelism and thus lowering the number of value per thread to reasonable limits i.e. less than 2048 values per thread. The performance gains can be between 10x-17x for certain examples where the number of values per thread was originally very high. cherry-pick of https://github.com/pytorch/pytorch/pull/159652 --- aten/src/ATen/native/cuda/Reduce.cuh | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 15a572804af5f..521b467480900 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -209,6 +209,10 @@ struct ReduceConfig { int values_per_thread() const { return div_up(num_inputs, step_input); } + + int mock_values_per_thread(int parallelism) { + return div_up(num_inputs, step_input * parallelism); + } }; std::ostream& operator<<(std::ostream& out, const ReduceConfig& config); @@ -1166,8 +1170,17 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ else if (config.ctas_per_output < 16) config.ctas_per_output = 1; bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast); - if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) + if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) { config.ctas_per_output = 4; + int vpt = config.values_per_thread(); + // Capping the number of values per thread to 2048 for now + // based on known use cases. + while (vpt >= 2048) { + config.ctas_per_output *= 2; + // Computes the new values per thread without side effects + vpt = config.mock_values_per_thread(config.ctas_per_output); + } + } #endif if (config.ctas_per_output > 1) { config.input_mult[2] = config.split_input(config.ctas_per_output);