From f2fc6731205f231754df9b272ef3a2765c57cea5 Mon Sep 17 00:00:00 2001
From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com>
Date: Fri, 22 Aug 2025 09:43:52 -0700
Subject: [PATCH] [ROCm] Unroll loads in global_reduce

cherry-pick of https://github.com/pytorch/pytorch/pull/161181
---
 aten/src/ATen/native/cuda/Reduce.cuh | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh
index 7d1c45e785b79..7cc71711d01d6 100644
--- a/aten/src/ATen/native/cuda/Reduce.cuh
+++ b/aten/src/ATen/native/cuda/Reduce.cuh
@@ -822,6 +822,23 @@ struct ReduceOp {
     } else {
       index_t input_offset = threadIdx.y;
       index_t step = blockDim.y;
+#ifdef USE_ROCM // Prefetch loads to better hide their latency
+  #define PRFCH 4
+      for (; input_offset < config.ctas_per_output; input_offset += step*PRFCH) {
+        arg_vec_t next[PRFCH];
+        #pragma unroll
+        for (int u = 0; (u < PRFCH) && (input_offset + u*step < config.ctas_per_output); u++) {
+          index_t idx = config.staging_memory_offset(input_offset + u*step);
+          next[u] = reduce_buffer[idx];
+        }
+        for (int u = 0; (u < PRFCH) && (input_offset + u*step < config.ctas_per_output); u++) {
+          #pragma unroll
+          for (int i = 0; i < output_vec_size; i++) {
+            value[i] = ops.combine(value[i], next[u][i]);
+          }
+        }
+      }
+#else
       for (; input_offset < config.ctas_per_output; input_offset += step) {
         index_t idx = config.staging_memory_offset(input_offset);
         arg_vec_t next = reduce_buffer[idx];
@@ -830,6 +847,7 @@
           value[i] = ops.combine(value[i], next[i]);
         }
       }
+#endif
     }
     value = block_y_reduce(value, shared_memory);
     if (config.should_block_x_reduce()) {
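
What the hunk changes, in isolation: the stock loop issues one staging-buffer load per trip and must consume it in the combine before the next trip's load is issued, so little memory latency is hidden; the ROCm branch batches PRFCH independent loads into registers first and only then runs the combines, letting the load latencies overlap. Below is a minimal standalone sketch of that same prefetch-unroll idea, assuming a plain float sum reduction. The `combine` helper, the `strided_reduce` kernel, and the launch shape are illustrative assumptions, not PyTorch code; the real kernel accumulates into a vector `value[output_vec_size]` and does its final combine across threads rather than with an atomic. It should build with nvcc, or with hipcc after a trivial hipify.

    // Sketch of the prefetch-unroll load pattern from the patch (not PyTorch code).
    #include <cstdio>
    #include <cuda_runtime.h>

    #define PRFCH 4  // loads in flight per thread before any combine, as in the patch

    __device__ __forceinline__ float combine(float a, float b) { return a + b; }

    // Grid-stride reduction: each thread walks the buffer `step` elements apart.
    __global__ void strided_reduce(const float* buf, float* out, int n) {
      int step = blockDim.x * gridDim.x;
      int offset = blockIdx.x * blockDim.x + threadIdx.x;
      float acc = 0.0f;
      for (; offset < n; offset += step * PRFCH) {
        float next[PRFCH];
        #pragma unroll
        for (int u = 0; u < PRFCH && offset + u * step < n; u++)
          next[u] = buf[offset + u * step];  // independent loads issued back to back
        for (int u = 0; u < PRFCH && offset + u * step < n; u++)
          acc = combine(acc, next[u]);       // combines run once the loads land
      }
      atomicAdd(out, acc);  // crude final reduction; fine for a sketch
    }

    int main() {
      const int n = 1 << 20;
      float *buf = nullptr, *out = nullptr;
      cudaMallocManaged(&buf, n * sizeof(float));
      cudaMallocManaged(&out, sizeof(float));
      for (int i = 0; i < n; i++) buf[i] = 1.0f;
      *out = 0.0f;
      strided_reduce<<<64, 256>>>(buf, out, n);
      cudaDeviceSynchronize();
      printf("sum = %.0f (expect %d)\n", *out, n);  // prints 1048576
      cudaFree(buf);
      cudaFree(out);
      return 0;
    }

Note that, as in the patch, the bounds check is repeated inside both inner loops rather than hoisted, so the tail of the buffer is handled without a separate cleanup loop.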