diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index c5a335e299614..3d4fa34735b8f 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -1156,7 +1156,8 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ config.ctas_per_output = div_up(num_mp, 2); else if (config.ctas_per_output < 16) config.ctas_per_output = 1; - if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension) + bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast); + if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) config.ctas_per_output = 4; #endif if (config.ctas_per_output > 1) {