diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index 5e6e59784ef39..8380530646175 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) { // first the reductions each thread does separately scalar_t sum = static_cast(0); for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) { +#if defined(USE_ROCM) + constexpr int UNRL = 4; // load deserialize factor + scalar_t tmp[UNRL]; + for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) { +#pragma unroll + for (int u = 0; u < UNRL; u++) + tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x))); +#pragma unroll + for (int u = 0; u < UNRL; u++) + if (x+u*blockDim.x < tensor.size(2)) + sum += tmp[u]; + } +#else for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) { sum += op(batch, plane, x); } +#endif } __shared__ scalar_t shared[C10_WARP_SIZE]; SumReduceOp reduce_op;