From 33d172bf9c9078514b47a93c32b42d14d50859ad Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Wed, 3 Sep 2025 11:45:19 -0700 Subject: [PATCH] [ROCm] OffsetCalc Unroll Optimization (#2597) cherry-pick of https://github.com/pytorch/pytorch/pull/161700 Our compiler is generating inefficient code for the offsetCalc in certain situations. The root cause of this still needs to be identified. For now, specialized unrolling based on 'dims' notably helps performance. Fixes SWDEV-545713, SWDEV-545710 --- aten/src/ATen/cuda/detail/OffsetCalculator.cuh | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh index 60e1a19c1aac..a65db3f2df12 100644 --- a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh +++ b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh @@ -45,6 +45,24 @@ struct OffsetCalculator { C10_HOST_DEVICE offset_type get(index_t linear_idx) const { offset_type offsets; + +#if defined(USE_ROCM) + if ((dims > 0) && (dims <= 2)) { + auto divmod = sizes_[0].divmod(linear_idx); + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) + offsets[arg] = divmod.mod * strides_[0][arg]; + if (dims >= 2) { + divmod = sizes_[1].divmod(divmod.div); + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) + offsets[arg] += divmod.mod * strides_[1][arg]; + } + // [...] + return offsets; + } +#endif + #pragma unroll for (int arg = 0; arg < NARGS; arg++) { offsets[arg] = 0;