From 5b0b711724fca9946d49e443e76c71b80735f6a6 Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Fri, 24 Oct 2025 11:08:22 -0700 Subject: [PATCH] Closed [ROCm] deserialize loads in planar sum portion of stats() of norm cherry-pick of https://github.com/pytorch/pytorch/commit/47f638eae7f6a26b6c8ea5625551c986aabd87c7 --- aten/src/ATen/native/cuda/Normalization.cuh | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index 8380530646175..0e3fc88b569c0 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -306,6 +306,22 @@ __global__ void batch_norm_collect_statistics_kernel( stat_accscalar_t var_n = 0; int n = 0; for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { +#if defined(USE_ROCM) + constexpr int UNRL = 4; + stat_accscalar_t v_[UNRL]; + for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) { + for (int u = 0; u < UNRL; u++) + v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)]; + for (int u = 0; u < UNRL; u++) { + if (x+u*blockDim.x < input.size(2)) { + stat_accscalar_t d1 = v_[u] - avg; + n++; + avg += d1 / n; + var_n += d1 * (v_[u] - avg); + } + } + } +#else for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) { stat_accscalar_t v = input[batch][plane][x]; stat_accscalar_t d1 = v - avg; @@ -313,6 +329,7 @@ __global__ void batch_norm_collect_statistics_kernel( avg += d1 / n; var_n += d1 * (v - avg); } +#endif } // first warpSum to get one value per thread to