From 510e65c85c0d045d84d5906b739f7d2d4ea21ba6 Mon Sep 17 00:00:00 2001
From: Jerry Mannil
Date: Wed, 21 May 2025 23:15:11 +0000
Subject: [PATCH] [ROCm] Fix 3D tensor perf degradation with NHWC format

---
 aten/src/ATen/native/cuda/Reduce.cuh | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh
index c5a335e299614..3d4fa34735b8f 100644
--- a/aten/src/ATen/native/cuda/Reduce.cuh
+++ b/aten/src/ATen/native/cuda/Reduce.cuh
@@ -1156,7 +1156,8 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
       config.ctas_per_output = div_up(num_mp, 2);
     else if (config.ctas_per_output < 16)
       config.ctas_per_output = 1;
-    if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension)
+    bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast);
+    if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last)
       config.ctas_per_output = 4;
 #endif
     if (config.ctas_per_output > 1) {