Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion aten/src/ATen/native/cuda/Reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ struct ReduceConfig {
int values_per_thread() const {
return div_up(num_inputs, step_input);
}

int mock_values_per_thread(int parallelism) {
return div_up(num_inputs, step_input * parallelism);
}
};

std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
Expand Down Expand Up @@ -1166,8 +1170,17 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
else if (config.ctas_per_output < 16)
config.ctas_per_output = 1;
bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast);
if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last)
if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) {
config.ctas_per_output = 4;
int vpt = config.values_per_thread();
// Capping the number of values per thread to 2048 for now
// based on known use cases.
while (vpt >= 2048) {
config.ctas_per_output *= 2;
// Computes the new values per thread without side effects
vpt = config.mock_values_per_thread(config.ctas_per_output);
}
}
#endif
if (config.ctas_per_output > 1) {
config.input_mult[2] = config.split_input(config.ctas_per_output);
Expand Down