From b31274f527974938f03d4f4b5e9375c3345154e1 Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Tue, 12 Aug 2025 19:16:08 -0700 Subject: [PATCH] [ROCm] Improve reduction sum performance (#2492) * Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 0 **Reproducer:** ``` import time import torch shapes = [ (5079670, 128) ] dims = [ (1) ] for i, shape in enumerate(shapes): x = torch.randn(shape, device='cuda', dtype=torch.float) for _ in range(10): w = torch.sum(x, dims[i]) torch.cuda.synchronize() print(w.size()) start_time = time.time() for _ in range(50): _ = torch.sum(x, dims[i]) torch.cuda.synchronize() end_time = time.time() mean_time = (end_time - start_time)/50 print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us") ``` **Before (MI300X):** Avg time for shape (5079670, 128): 1629.99 us **After (MI300X)** Avg time for shape (5079670, 128): 1008.59 us cherry-pick of https://github.com/pytorch/pytorch/pull/160466 Fixes SWDEV-546136 --- aten/src/ATen/native/cuda/Reduce.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 521b467480900..7d1c45e785b79 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -1062,7 +1062,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ // In such case, values in each loaded vector always correspond to different outputs. if (fastest_moving_stride == sizeof(scalar_t)) { #ifdef USE_ROCM - if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1) { + if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) { #else if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) { #endif