diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh
index 15a572804af5f..521b467480900 100644
--- a/aten/src/ATen/native/cuda/Reduce.cuh
+++ b/aten/src/ATen/native/cuda/Reduce.cuh
@@ -209,6 +209,10 @@ struct ReduceConfig {
   int values_per_thread() const {
     return div_up(num_inputs, step_input);
   }
+
+  int mock_values_per_thread(int parallelism) {
+    return div_up(num_inputs, step_input * parallelism);
+  }
 };
 
 std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
@@ -1166,8 +1170,17 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
     else if (config.ctas_per_output < 16)
       config.ctas_per_output = 1;
     bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast);
-    if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last)
+    if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) {
       config.ctas_per_output = 4;
+      int vpt = config.values_per_thread();
+      // Capping the number of values per thread to 2048 for now
+      // based on known use cases.
+      while (vpt >= 2048) {
+        config.ctas_per_output *= 2;
+        // Computes the new values per thread without side effects
+        vpt = config.mock_values_per_thread(config.ctas_per_output);
+      }
+    }
 #endif
   if (config.ctas_per_output > 1) {
     config.input_mult[2] = config.split_input(config.ctas_per_output);
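
For reference, below is a minimal standalone sketch of the capping loop this patch introduces. The num_inputs and step_input values are hypothetical, and the local div_up simply mirrors ATen's ceiling-division helper; it is an illustration of the heuristic, not the library code. As in the patch, the initial estimate does not account for the starting ctas_per_output of 4, and each iteration doubles the CTA count before recomputing.

#include <cstdio>

// Mirrors ATen's div_up helper: ceiling division.
static int div_up(int a, int b) {
  return (a + b - 1) / b;
}

int main() {
  const int num_inputs = 1 << 20; // hypothetical reduction size
  const int step_input = 32;      // hypothetical input step
  int ctas_per_output = 4;        // starting value set by the patch

  // Initial estimate, as in values_per_thread(); like the patched code,
  // it does not yet fold in ctas_per_output.
  int vpt = div_up(num_inputs, step_input);

  // Double the CTA count until each thread handles fewer than 2048 values.
  while (vpt >= 2048) {
    ctas_per_output *= 2;
    // Same computation as mock_values_per_thread(): no config state mutated.
    vpt = div_up(num_inputs, step_input * ctas_per_output);
  }

  // With these inputs, prints: ctas_per_output=32 values_per_thread=1024
  printf("ctas_per_output=%d values_per_thread=%d\n", ctas_per_output, vpt);
  return 0;
}

Keeping the recomputation side-effect free matters here because the actual split (config.split_input) is only applied afterwards, in the if (config.ctas_per_output > 1) branch visible at the end of the second hunk.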