diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index 8380530646175..0e3fc88b569c0 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -306,6 +306,22 @@ __global__ void batch_norm_collect_statistics_kernel( stat_accscalar_t var_n = 0; int n = 0; for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { +#if defined(USE_ROCM) + constexpr int UNRL = 4; + stat_accscalar_t v_[UNRL]; + for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) { + for (int u = 0; u < UNRL; u++) + v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)]; + for (int u = 0; u < UNRL; u++) { + if (x+u*blockDim.x < input.size(2)) { + stat_accscalar_t d1 = v_[u] - avg; + n++; + avg += d1 / n; + var_n += d1 * (v_[u] - avg); + } + } + } +#else for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) { stat_accscalar_t v = input[batch][plane][x]; stat_accscalar_t d1 = v - avg; @@ -313,6 +329,7 @@ __global__ void batch_norm_collect_statistics_kernel( avg += d1 / n; var_n += d1 * (v - avg); } +#endif } // first warpSum to get one value per thread to