From 0ca7fb9519afef7dac7956a0552e711bd47062bb Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:21:07 -0700 Subject: [PATCH] [ROCm] Improve perf for elementwise broadcast with mixed dtype * cherry-pick of https://github.com/pytorch/pytorch/commit/2aadcea05c3c57eaf87cd90e018c1ee0379e7e88 --- aten/src/ATen/native/cuda/CUDALoops.cuh | 29 +++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index f96b8d687bc8..609393530eac 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -999,12 +999,41 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { dtypes[i] = iter.dtype(i); } auto offset_calc = ::make_offset_calculator(iter); +#ifdef USE_ROCM + constexpr int grp_sz = 128; + launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { + if (unrl) { + auto offsets0 = offset_calc.get(idx); + auto offsets1 = offset_calc.get(idx + grp_sz); + auto offsets2 = offset_calc.get(idx + grp_sz * 2); + auto offsets3 = offset_calc.get(idx + grp_sz * 3); + void* out0 = data[0] + offsets0[0]; + void* out1 = data[0] + offsets1[0]; + void* out2 = data[0] + offsets2[0]; + void* out3 = data[0] + offsets3[0]; + arg0_t result0 = invoke(f, &data[1], &offsets0[1], &dtypes[1], 1); + arg0_t result1 = invoke(f, &data[1], &offsets1[1], &dtypes[1], 1); + arg0_t result2 = invoke(f, &data[1], &offsets2[1], &dtypes[1], 1); + arg0_t result3 = invoke(f, &data[1], &offsets3[1], &dtypes[1], 1); + c10::cast_and_store(dtypes[0], out0, result0); + c10::cast_and_store(dtypes[0], out1, result1); + c10::cast_and_store(dtypes[0], out2, result2); + c10::cast_and_store(dtypes[0], out3, result3); + } else { + auto offsets = offset_calc.get(idx); + void* out = data[0] + offsets[0]; + arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1); + c10::cast_and_store(dtypes[0], out, result); + } + }); +#else launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) { auto offsets = offset_calc.get(idx); void* out = data[0] + offsets[0]; arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1); c10::cast_and_store(dtypes[0], out, result); }); +#endif } }