diff --git a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh index 60e1a19c1aac..a65db3f2df12 100644 --- a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh +++ b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh @@ -45,6 +45,24 @@ struct OffsetCalculator { C10_HOST_DEVICE offset_type get(index_t linear_idx) const { offset_type offsets; + +#if defined(USE_ROCM) + if ((dims > 0) && (dims <= 2)) { + auto divmod = sizes_[0].divmod(linear_idx); + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) + offsets[arg] = divmod.mod * strides_[0][arg]; + if (dims >= 2) { + divmod = sizes_[1].divmod(divmod.div); + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) + offsets[arg] += divmod.mod * strides_[1][arg]; + } + // [...] + return offsets; + } +#endif + #pragma unroll for (int arg = 0; arg < NARGS; arg++) { offsets[arg] = 0;