From 1b5dc89042c4b7804494022371ae5c4ae299912f Mon Sep 17 00:00:00 2001
From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com>
Date: Thu, 23 Oct 2025 09:10:55 -0700
Subject: [PATCH] [ROCm] Deserialize loads in planar sum portion of reduce()
 of norm (#2740)

Cherry-pick of https://github.com/pytorch/pytorch/commit/6b7cd48e7eee74e4e565f4f3c14bf2d57ec41858

Fixes #SWDEV-561122
---
 aten/src/ATen/native/cuda/Normalization.cuh | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh
index 5e6e59784ef39..8380530646175 100644
--- a/aten/src/ATen/native/cuda/Normalization.cuh
+++ b/aten/src/ATen/native/cuda/Normalization.cuh
@@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
   // first the reductions each thread does separately
   scalar_t sum = static_cast<scalar_t>(0);
   for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
+#if defined(USE_ROCM)
+    constexpr int UNRL = 4; // load de-serialization (unroll) factor
+    scalar_t tmp[UNRL];
+    for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
+#pragma unroll
+      for (int u = 0; u < UNRL; u++)
+        tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
+#pragma unroll
+      for (int u = 0; u < UNRL; u++)
+        if (x+u*blockDim.x < tensor.size(2))
+          sum += tmp[u];
+    }
+#else
     for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
       sum += op(batch, plane, x);
     }
+#endif
   }
   __shared__ scalar_t shared[C10_WARP_SIZE];
   SumReduceOp<scalar_t> reduce_op;
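
--
Note (not part of the commit): the change replaces one load-then-accumulate
per loop trip with UNRL=4 index-clamped loads issued back to back, so the
hardware can overlap their memory latency instead of stalling on each
`sum +=` before the next load is issued. Below is a minimal standalone
sketch of the same pattern, assuming a plain 1-D float sum kernel; the
kernel name `strided_sum` and its shape are illustrative, not from the
patch, and the real reduce() finishes with a shared-memory reduction
rather than an atomic.

  #include <hip/hip_runtime.h>  // builds as CUDA with <cuda_runtime.h>

  __global__ void strided_sum(const float* in, float* out, int n) {
    constexpr int UNRL = 4;             // same unroll factor as the patch
    float sum = 0.0f;
    for (int x = threadIdx.x; x < n; x += blockDim.x * UNRL) {
      float tmp[UNRL];
  #pragma unroll
      for (int u = 0; u < UNRL; u++)    // UNRL independent loads in flight
        tmp[u] = in[min(n - 1, (int)(x + u * blockDim.x))];  // clamped index
  #pragma unroll
      for (int u = 0; u < UNRL; u++)
        if (x + u * blockDim.x < n)     // discard the clamped duplicates
          sum += tmp[u];
    }
    atomicAdd(out, sum);                // sketch only; not how reduce() combines
  }

Clamping to tensor.size(2)-1 instead of predicating the loads keeps the
unrolled load loop branch-free; the out-of-range duplicates are filtered
in the second loop before accumulation. The change is gated on USE_ROCM,
where the unrolled form was evidently the profitable one.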