diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 3c6577af4286..f2b7bc0ebc02 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -684,8 +684,6 @@ impl Stream for GroupedHashAggregateStream { // Do the grouping self.group_aggregate_batch(batch)?; - self.update_skip_aggregation_probe(input_rows); - // If we can begin emitting rows, do so, // otherwise keep consuming input assert!(!self.input_done); @@ -709,9 +707,22 @@ impl Stream for GroupedHashAggregateStream { break 'reading_input; } - self.emit_early_if_necessary()?; + // Check if we should switch to skip aggregation mode + // It's important that we do this before we early emit since we've + // already updated the probe. + self.update_skip_aggregation_probe(input_rows); + if let Some(new_state) = self.switch_to_skip_aggregation()? { + timer.done(); + self.exec_state = new_state; + break 'reading_input; + } - self.switch_to_skip_aggregation()?; + // Check if we need to emit early due to memory pressure + if let Some(new_state) = self.emit_early_if_necessary()? { + timer.done(); + self.exec_state = new_state; + break 'reading_input; + } timer.done(); } @@ -785,6 +796,15 @@ impl Stream for GroupedHashAggregateStream { } None => { // inner is done, switching to `Done` state + // Sanity check: when switching from SkippingAggregation to Done, + // all groups should have already been emitted + if !self.group_values.is_empty() { + return Poll::Ready(Some(internal_err!( + "Switching from SkippingAggregation to Done with {} groups still in hash table. \ + This is a bug - all groups should have been emitted before skip aggregation started.", + self.group_values.len() + ))); + } self.exec_state = ExecutionState::Done; } } @@ -832,6 +852,13 @@ impl Stream for GroupedHashAggregateStream { } ExecutionState::Done => { + // Sanity check: all groups should have been emitted by now + if !self.group_values.is_empty() { + return Poll::Ready(Some(internal_err!( + "AggregateStream was in Done state with {} groups left in hash table. \ + This is a bug - all groups should have been emitted before entering Done state.", + self.group_values.len()))); + } // release the memory reservation since sending back output batch itself needs // some memory reservation, so make some room for it. self.clear_all(); @@ -1096,7 +1123,9 @@ impl GroupedHashAggregateStream { /// Emit if the used memory exceeds the target for partial aggregation. /// Currently only [`GroupOrdering::None`] is supported for early emitting. /// TODO: support group_ordering for early emitting - fn emit_early_if_necessary(&mut self) -> Result<()> { + /// + /// Returns `Some(ExecutionState)` if the state should be changed, None otherwise. + fn emit_early_if_necessary(&mut self) -> Result> { if self.group_values.len() >= self.batch_size && matches!(self.group_ordering, GroupOrdering::None) && self.update_memory_reservation().is_err() @@ -1104,10 +1133,10 @@ impl GroupedHashAggregateStream { assert_eq!(self.mode, AggregateMode::Partial); let n = self.group_values.len() / self.batch_size * self.batch_size; if let Some(batch) = self.emit(EmitTo::First(n), false)? { - self.exec_state = ExecutionState::ProducingOutput(batch); + return Ok(Some(ExecutionState::ProducingOutput(batch))); }; } - Ok(()) + Ok(None) } /// At this point, all the inputs are read and there are some spills. @@ -1190,16 +1219,18 @@ impl GroupedHashAggregateStream { /// skipped, forces stream to produce currently accumulated output. /// /// Notice: It should only be called in Partial aggregation - fn switch_to_skip_aggregation(&mut self) -> Result<()> { + /// + /// Returns `Some(ExecutionState)` if the state should be changed, None otherwise. + fn switch_to_skip_aggregation(&mut self) -> Result> { if let Some(probe) = self.skip_aggregation_probe.as_mut() { if probe.should_skip() { if let Some(batch) = self.emit(EmitTo::All, false)? { - self.exec_state = ExecutionState::ProducingOutput(batch); + return Ok(Some(ExecutionState::ProducingOutput(batch))); }; } } - Ok(()) + Ok(None) } /// Returns true if the aggregation probe indicates that aggregation @@ -1239,3 +1270,123 @@ impl GroupedHashAggregateStream { Ok(states_batch) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::TestMemoryExec; + use arrow::array::{Int32Array, Int64Array}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_execution::TaskContext; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::col; + use std::sync::Arc; + + #[tokio::test] + async fn test_double_emission_race_condition_bug() -> Result<()> { + // Fix for https://github.com/apache/datafusion/issues/18701 + // This test specifically proves that we have fixed double emission race condition + // where emit_early_if_necessary() and switch_to_skip_aggregation() + // both emit in the same loop iteration, causing data loss + + let schema = Arc::new(Schema::new(vec![ + Field::new("group_col", DataType::Int32, false), + Field::new("value_col", DataType::Int64, false), + ])); + + // Create data that will trigger BOTH conditions in the same iteration: + // 1. More groups than batch_size (triggers early emission when memory pressure hits) + // 2. High cardinality ratio (triggers skip aggregation) + let batch_size = 1024; // We'll set this in session config + let num_groups = batch_size + 100; // Slightly more than batch_size (1124 groups) + + // Create exactly 1 row per group = 100% cardinality ratio + let group_ids: Vec = (0..num_groups as i32).collect(); + let values: Vec = vec![1; num_groups]; + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(group_ids)), + Arc::new(Int64Array::from(values)), + ], + )?; + + let input_partitions = vec![vec![batch]]; + + // Create constrained memory to trigger early emission but not completely fail + let runtime = RuntimeEnvBuilder::default() + .with_memory_limit(1024, 1.0) // 100KB - enough to start but will trigger pressure + .build_arc()?; + + let mut task_ctx = TaskContext::default().with_runtime(runtime); + + // Configure to trigger BOTH conditions: + // 1. Low probe threshold (triggers skip probe after few rows) + // 2. Low ratio threshold (triggers skip aggregation immediately) + // 3. Set batch_size to 1024 so our 1124 groups will trigger early emission + // This creates the race condition where both emit paths are triggered + let mut session_config = task_ctx.session_config().clone(); + session_config = session_config.set( + "datafusion.execution.batch_size", + &datafusion_common::ScalarValue::UInt64(Some(1024)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &datafusion_common::ScalarValue::UInt64(Some(50)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold", + &datafusion_common::ScalarValue::Float64(Some(0.8)), + ); + task_ctx = task_ctx.with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + + // Create aggregate: COUNT(*) GROUP BY group_col + let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())]; + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count_value") + .build()?, + )]; + + let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?; + let exec = Arc::new(TestMemoryExec::update_cache(Arc::new(exec))); + + // Use Partial mode where the race condition occurs + let aggregate_exec = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(group_expr), + aggr_expr, + vec![None], + exec, + Arc::clone(&schema), + )?; + + // Execute and collect results + let mut stream = + GroupedHashAggregateStream::new(&aggregate_exec, Arc::clone(&task_ctx), 0)?; + let mut results = Vec::new(); + + while let Some(result) = stream.next().await { + let batch = result?; + results.push(batch); + } + + // Count total groups emitted + let mut total_output_groups = 0; + for batch in &results { + total_output_groups += batch.num_rows(); + } + + assert_eq!( + total_output_groups, num_groups, + "Unexpected number of groups", + ); + + Ok(()) + } +}