diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 3c6577af4286..f2b7bc0ebc02 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -684,8 +684,6 @@ impl Stream for GroupedHashAggregateStream {
// Do the grouping
self.group_aggregate_batch(batch)?;
- self.update_skip_aggregation_probe(input_rows);
-
// If we can begin emitting rows, do so,
// otherwise keep consuming input
assert!(!self.input_done);
@@ -709,9 +707,22 @@ impl Stream for GroupedHashAggregateStream {
break 'reading_input;
}
- self.emit_early_if_necessary()?;
+ // Check if we should switch to skip aggregation mode
+ // It's important that we do this before we early emit since we've
+ // already updated the probe.
+ self.update_skip_aggregation_probe(input_rows);
+ if let Some(new_state) = self.switch_to_skip_aggregation()? {
+ timer.done();
+ self.exec_state = new_state;
+ break 'reading_input;
+ }
- self.switch_to_skip_aggregation()?;
+ // Check if we need to emit early due to memory pressure
+ if let Some(new_state) = self.emit_early_if_necessary()? {
+ timer.done();
+ self.exec_state = new_state;
+ break 'reading_input;
+ }
timer.done();
}
@@ -785,6 +796,15 @@ impl Stream for GroupedHashAggregateStream {
}
None => {
// inner is done, switching to `Done` state
+ // Sanity check: when switching from SkippingAggregation to Done,
+ // all groups should have already been emitted
+ if !self.group_values.is_empty() {
+ return Poll::Ready(Some(internal_err!(
+ "Switching from SkippingAggregation to Done with {} groups still in hash table. \
+ This is a bug - all groups should have been emitted before skip aggregation started.",
+ self.group_values.len()
+ )));
+ }
self.exec_state = ExecutionState::Done;
}
}
@@ -832,6 +852,13 @@ impl Stream for GroupedHashAggregateStream {
}
ExecutionState::Done => {
+ // Sanity check: all groups should have been emitted by now
+ if !self.group_values.is_empty() {
+ return Poll::Ready(Some(internal_err!(
+ "AggregateStream was in Done state with {} groups left in hash table. \
+ This is a bug - all groups should have been emitted before entering Done state.",
+ self.group_values.len())));
+ }
// release the memory reservation since sending back output batch itself needs
// some memory reservation, so make some room for it.
self.clear_all();
@@ -1096,7 +1123,9 @@ impl GroupedHashAggregateStream {
/// Emit if the used memory exceeds the target for partial aggregation.
/// Currently only [`GroupOrdering::None`] is supported for early emitting.
/// TODO: support group_ordering for early emitting
- fn emit_early_if_necessary(&mut self) -> Result<()> {
+ ///
+ /// Returns `Some(ExecutionState)` if the state should be changed, `None` otherwise.
+ fn emit_early_if_necessary(&mut self) -> Result