171 changes: 161 additions & 10 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -684,8 +684,6 @@ impl Stream for GroupedHashAggregateStream {
// Do the grouping
self.group_aggregate_batch(batch)?;

self.update_skip_aggregation_probe(input_rows);

// If we can begin emitting rows, do so,
// otherwise keep consuming input
assert!(!self.input_done);
@@ -709,9 +707,22 @@ impl Stream for GroupedHashAggregateStream {
break 'reading_input;
}

self.emit_early_if_necessary()?;
// Check if we should switch to skip aggregation mode
// It's important that we do this before the early-emit check,
// since the probe has already been updated.
self.update_skip_aggregation_probe(input_rows);
if let Some(new_state) = self.switch_to_skip_aggregation()? {
@xanderbailey (Contributor, Author) commented on Nov 15, 2025:

It's unclear to me why update_skip_aggregation_probe was previously called here https://github.com/apache/datafusion/pull/18712/files#diff-69c8ecaca5e2c7005f2ed1facaa41f80b45bfd006f2357e53ff3072f535c287dL687 and not next to switch_to_skip_aggregation. I can't fully explain it yet, but allowing the probe to be updated and then allowing the loop to break before we get here seems dangerous. It's important that we emit everything before we move to the SkipAggregation state?

@xanderbailey: @korowa does that make sense to you?

timer.done();
self.exec_state = new_state;
break 'reading_input;
}

self.switch_to_skip_aggregation()?;
// Check if we need to emit early due to memory pressure
if let Some(new_state) = self.emit_early_if_necessary()? {
timer.done();
self.exec_state = new_state;
break 'reading_input;
}

timer.done();
}
@@ -785,6 +796,15 @@ }
}
None => {
// inner is done, switching to `Done` state
// Sanity check: when switching from SkippingAggregation to Done,
// all groups should have already been emitted
if !self.group_values.is_empty() {
return Poll::Ready(Some(internal_err!(
"Switching from SkippingAggregation to Done with {} groups still in hash table. \
This is a bug - all groups should have been emitted before skip aggregation started.",
self.group_values.len()
)));
}
self.exec_state = ExecutionState::Done;
}
}
@@ -832,6 +852,13 @@ }
}

ExecutionState::Done => {
// Sanity check: all groups should have been emitted by now
if !self.group_values.is_empty() {
@xanderbailey (Contributor, Author) commented:

Adding some protection here to try to avoid bugs like this happening in the future.

@xanderbailey: I'd rather fail the query than return wrong or incomplete results.

return Poll::Ready(Some(internal_err!(
"AggregateStream was in Done state with {} groups left in hash table. \
This is a bug - all groups should have been emitted before entering Done state.",
self.group_values.len())));
}
// release the memory reservation since sending back output batch itself needs
// some memory reservation, so make some room for it.
self.clear_all();
@@ -1096,18 +1123,20 @@ impl GroupedHashAggregateStream {
/// Emit if the used memory exceeds the target for partial aggregation.
/// Currently only [`GroupOrdering::None`] is supported for early emitting.
/// TODO: support group_ordering for early emitting
fn emit_early_if_necessary(&mut self) -> Result<()> {
///
/// Returns `Some(ExecutionState)` if the state should be changed, None otherwise.
fn emit_early_if_necessary(&mut self) -> Result<Option<ExecutionState>> {
if self.group_values.len() >= self.batch_size
&& matches!(self.group_ordering, GroupOrdering::None)
&& self.update_memory_reservation().is_err()
{
assert_eq!(self.mode, AggregateMode::Partial);
let n = self.group_values.len() / self.batch_size * self.batch_size;
if let Some(batch) = self.emit(EmitTo::First(n), false)? {
self.exec_state = ExecutionState::ProducingOutput(batch);
return Ok(Some(ExecutionState::ProducingOutput(batch)));
};
}
Ok(())
Ok(None)
}
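A quick worked example of the `n = len / batch_size * batch_size` rounding above, in plain Rust arithmetic using the same counts as the test at the bottom of this diff:

```rust
fn main() {
    // emit_early_if_necessary emits the largest multiple of batch_size that
    // fits, keeping the remainder in the hash table for further aggregation.
    let batch_size = 1024;
    let group_count = 1124; // as in the test: batch_size + 100
    let n = group_count / batch_size * batch_size; // integer division rounds down
    assert_eq!(n, 1024); // EmitTo::First(1024): 1024 groups emitted, 100 retained
    println!("emit first {n} of {group_count} groups");
}
```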

/// At this point, all the inputs are read and there are some spills.
@@ -1190,16 +1219,18 @@ impl GroupedHashAggregateStream {
/// skipped, forces stream to produce currently accumulated output.
///
/// Notice: It should only be called in Partial aggregation
fn switch_to_skip_aggregation(&mut self) -> Result<()> {
///
/// Returns `Some(ExecutionState)` if the state should be changed, None otherwise.
fn switch_to_skip_aggregation(&mut self) -> Result<Option<ExecutionState>> {
if let Some(probe) = self.skip_aggregation_probe.as_mut() {
if probe.should_skip() {
if let Some(batch) = self.emit(EmitTo::All, false)? {
self.exec_state = ExecutionState::ProducingOutput(batch);
return Ok(Some(ExecutionState::ProducingOutput(batch)));
};
}
}

Ok(())
Ok(None)
}

/// Returns true if the aggregation probe indicates that aggregation
@@ -1239,3 +1270,123 @@ impl GroupedHashAggregateStream {
Ok(states_batch)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test::TestMemoryExec;
use arrow::array::{Int32Array, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_execution::TaskContext;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::col;
use std::sync::Arc;

#[tokio::test]
async fn test_double_emission_race_condition_bug() -> Result<()> {
// Fix for https://github.com/apache/datafusion/issues/18701
// This test proves that we have fixed the double-emission race condition,
// where emit_early_if_necessary() and switch_to_skip_aggregation()
// both emit in the same loop iteration, causing data loss.

let schema = Arc::new(Schema::new(vec![
Field::new("group_col", DataType::Int32, false),
Field::new("value_col", DataType::Int64, false),
]));

// Create data that will trigger BOTH conditions in the same iteration:
// 1. More groups than batch_size (triggers early emission when memory pressure hits)
// 2. High cardinality ratio (triggers skip aggregation)
let batch_size = 1024; // We'll set this in session config
let num_groups = batch_size + 100; // Slightly more than batch_size (1124 groups)

// Create exactly 1 row per group = 100% cardinality ratio
let group_ids: Vec<i32> = (0..num_groups as i32).collect();
let values: Vec<i64> = vec![1; num_groups];

let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(group_ids)),
Arc::new(Int64Array::from(values)),
],
)?;

let input_partitions = vec![vec![batch]];

// Constrain memory enough to trigger early emission without failing the query
let runtime = RuntimeEnvBuilder::default()
.with_memory_limit(1024, 1.0) // 1 KB - enough to start, but small enough to trigger pressure
.build_arc()?;

let mut task_ctx = TaskContext::default().with_runtime(runtime);

// Configure to trigger BOTH conditions:
// 1. Low probe threshold (triggers skip probe after few rows)
// 2. Low ratio threshold (triggers skip aggregation immediately)
// 3. Set batch_size to 1024 so our 1124 groups will trigger early emission
// This creates the race condition where both emit paths are triggered
let mut session_config = task_ctx.session_config().clone();
session_config = session_config.set(
"datafusion.execution.batch_size",
&datafusion_common::ScalarValue::UInt64(Some(1024)),
);
session_config = session_config.set(
"datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
&datafusion_common::ScalarValue::UInt64(Some(50)),
);
session_config = session_config.set(
"datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
&datafusion_common::ScalarValue::Float64(Some(0.8)),
);
task_ctx = task_ctx.with_session_config(session_config);
let task_ctx = Arc::new(task_ctx);

// Create aggregate: COUNT(*) GROUP BY group_col
let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())];
let aggr_expr = vec![Arc::new(
AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?])
.schema(Arc::clone(&schema))
.alias("count_value")
.build()?,
)];

let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
let exec = Arc::new(TestMemoryExec::update_cache(Arc::new(exec)));

// Use Partial mode where the race condition occurs
let aggregate_exec = AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::new_single(group_expr),
aggr_expr,
vec![None],
exec,
Arc::clone(&schema),
)?;

// Execute and collect results
let mut stream =
GroupedHashAggregateStream::new(&aggregate_exec, Arc::clone(&task_ctx), 0)?;
let mut results = Vec::new();

while let Some(result) = stream.next().await {
let batch = result?;
results.push(batch);
}

// Count total groups emitted
let mut total_output_groups = 0;
for batch in &results {
total_output_groups += batch.num_rows();
}

assert_eq!(
total_output_groups, num_groups,
"Unexpected number of groups",
);

Ok(())
}
}