Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,14 @@ pub(crate) struct GroupedHashAggregateStream {
/// max rows in output RecordBatches
batch_size: usize,

/// Max rows per emitted output batch. Matches `batch_size` for most
/// modes, but for `Partial` / `PartialReduce` — whose output feeds
/// a hash repartition that will split every input batch into `P`
/// sub-batches — we scale this to `batch_size * P` so the
/// post-repartition sub-batches already land at ~`batch_size` and
/// a downstream `CoalesceBatchesExec` has nothing to do.
emit_batch_size: usize,

/// Optional soft limit on the number of `group_values` in a batch
/// If the number of `group_values` in a single batch exceeds this value,
/// the `GroupedHashAggregateStream` operation immediately switches to
Expand Down Expand Up @@ -470,6 +478,11 @@ impl GroupedHashAggregateStream {
let agg_filter_expr = Arc::clone(&agg.filter_expr);

let batch_size = context.session_config().batch_size();
let emit_batch_size = match agg.mode {
AggregateMode::Partial | AggregateMode::PartialReduce => batch_size
.saturating_mul(context.session_config().target_partitions().max(1)),
_ => batch_size,
};
let input = agg.input.execute(partition, Arc::clone(context))?;
let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
let group_by_metrics = GroupByMetrics::new(&agg.metrics, partition);
Expand Down Expand Up @@ -675,6 +688,7 @@ impl GroupedHashAggregateStream {
baseline_metrics,
group_by_metrics,
batch_size,
emit_batch_size,
group_ordering,
input_done: false,
spill_state,
Expand Down Expand Up @@ -842,7 +856,7 @@ impl Stream for GroupedHashAggregateStream {
ExecutionState::ProducingOutput(batch) => {
// slice off a part of the batch, if needed
let output_batch;
let size = self.batch_size;
let size = self.emit_batch_size;
(self.exec_state, output_batch) = if batch.num_rows() <= size {
(
if self.input_done {
Expand All @@ -860,8 +874,7 @@ impl Stream for GroupedHashAggregateStream {
batch.clone(),
)
} else {
// output first batch_size rows
let size = self.batch_size;
// output first `emit_batch_size` rows
let num_remaining = batch.num_rows() - size;
let remaining = batch.slice(size, num_remaining);
let output = batch.slice(0, size);
Expand Down
Loading