Fix Partial AggregateExec correctness issue dropping rows #18712
base: main
Changes from all commits: d9e1d85, 416f19d, 7449fa6, 36df60c, febdd78, b1bde9e, ac72aba
```diff
@@ -684,8 +684,6 @@ impl Stream for GroupedHashAggregateStream {
                         // Do the grouping
                         self.group_aggregate_batch(batch)?;

-                        self.update_skip_aggregation_probe(input_rows);
-
                         // If we can begin emitting rows, do so,
                         // otherwise keep consuming input
                         assert!(!self.input_done);
```
```diff
@@ -709,9 +707,22 @@ impl Stream for GroupedHashAggregateStream {
                             break 'reading_input;
                         }

-                        self.emit_early_if_necessary()?;
+                        // Check if we should switch to skip aggregation mode
+                        // It's important that we do this before we early emit since we've
+                        // already updated the probe.
+                        self.update_skip_aggregation_probe(input_rows);
+                        if let Some(new_state) = self.switch_to_skip_aggregation()? {
+                            timer.done();
+                            self.exec_state = new_state;
+                            break 'reading_input;
+                        }

-                        self.switch_to_skip_aggregation()?;
+                        // Check if we need to emit early due to memory pressure
+                        if let Some(new_state) = self.emit_early_if_necessary()? {
+                            timer.done();
+                            self.exec_state = new_state;
+                            break 'reading_input;
+                        }

                         timer.done();
                     }
```
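To make the hazard concrete: before this change, `emit_early_if_necessary` and `switch_to_skip_aggregation` each set `self.exec_state = ExecutionState::ProducingOutput(batch)` internally. If both fired in the same loop iteration, the second assignment overwrote the first staged batch, and those rows were silently dropped. Below is a minimal, self-contained sketch of that failure mode and of the fixed shape; the types here (`ExecState`, `Stream`) are simplified stand-ins, not DataFusion's actual ones.

```rust
// Simplified stand-ins for ExecutionState / GroupedHashAggregateStream.
#[derive(Debug)]
enum ExecState {
    Reading,
    Producing(Vec<i32>), // stands in for ExecutionState::ProducingOutput(batch)
}

struct Stream {
    state: ExecState,
    pending: Vec<i32>,
}

impl Stream {
    // Buggy shape: the helper drains rows AND mutates the state directly.
    fn emit_early_buggy(&mut self) {
        let batch: Vec<i32> = self.pending.drain(..2).collect();
        self.state = ExecState::Producing(batch);
    }
    // A second helper firing in the same iteration overwrites the staged batch.
    fn skip_aggregation_buggy(&mut self) {
        let batch: Vec<i32> = self.pending.drain(..).collect();
        self.state = ExecState::Producing(batch);
    }
    // Fixed shape: return the desired state; the caller applies it once and
    // breaks the read loop before any other emission path can run.
    fn emit_early_fixed(&mut self) -> Option<ExecState> {
        Some(ExecState::Producing(self.pending.drain(..2).collect()))
    }
}

fn main() {
    let mut s = Stream { state: ExecState::Reading, pending: vec![1, 2, 3, 4] };
    s.emit_early_buggy();       // stages batch [1, 2]
    s.skip_aggregation_buggy(); // overwrites it with [3, 4]; rows 1 and 2 are lost
    println!("buggy: {:?}", s.state);

    let mut s = Stream { state: ExecState::Reading, pending: vec![1, 2, 3, 4] };
    if let Some(new_state) = s.emit_early_fixed() {
        s.state = new_state; // the real code does `break 'reading_input` here
    }
    println!("fixed: {:?}", s.state); // [3, 4] stay pending for the next poll
}
```

This is why the patched loop applies at most one state change per iteration and breaks out of `'reading_input` immediately after applying it.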
```diff
@@ -785,6 +796,15 @@ impl Stream for GroupedHashAggregateStream {
                     }
                     None => {
                         // inner is done, switching to `Done` state
+                        // Sanity check: when switching from SkippingAggregation to Done,
+                        // all groups should have already been emitted
+                        if !self.group_values.is_empty() {
+                            return Poll::Ready(Some(internal_err!(
+                                "Switching from SkippingAggregation to Done with {} groups still in hash table. \
+                                This is a bug - all groups should have been emitted before skip aggregation started.",
+                                self.group_values.len()
+                            )));
+                        }
                         self.exec_state = ExecutionState::Done;
                     }
                 }
```
```diff
@@ -832,6 +852,13 @@ impl Stream for GroupedHashAggregateStream {
                 }

                 ExecutionState::Done => {
+                    // Sanity check: all groups should have been emitted by now
+                    if !self.group_values.is_empty() {
+                        return Poll::Ready(Some(internal_err!(
+                            "AggregateStream was in Done state with {} groups left in hash table. \
+                            This is a bug - all groups should have been emitted before entering Done state.",
+                            self.group_values.len())));
+                    }
                     // release the memory reservation since sending back output batch itself needs
                     // some memory reservation, so make some room for it.
                     self.clear_all();
```

**Author (Contributor):** Adding some protection here to try and avoid bugs like this happening in the future.

**Author (Contributor):** I'd rather fail the query than return the wrong / incomplete results.
```diff
@@ -1096,18 +1123,20 @@ impl GroupedHashAggregateStream {
     /// Emit if the used memory exceeds the target for partial aggregation.
     /// Currently only [`GroupOrdering::None`] is supported for early emitting.
     /// TODO: support group_ordering for early emitting
-    fn emit_early_if_necessary(&mut self) -> Result<()> {
+    ///
+    /// Returns `Some(ExecutionState)` if the state should be changed, None otherwise.
+    fn emit_early_if_necessary(&mut self) -> Result<Option<ExecutionState>> {
         if self.group_values.len() >= self.batch_size
             && matches!(self.group_ordering, GroupOrdering::None)
             && self.update_memory_reservation().is_err()
         {
             assert_eq!(self.mode, AggregateMode::Partial);
             let n = self.group_values.len() / self.batch_size * self.batch_size;
             if let Some(batch) = self.emit(EmitTo::First(n), false)? {
-                self.exec_state = ExecutionState::ProducingOutput(batch);
+                return Ok(Some(ExecutionState::ProducingOutput(batch)));
             };
         }
-        Ok(())
+        Ok(None)
     }

     /// At this point, all the inputs are read and there are some spills.
```
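As a worked example of the cutoff arithmetic in `emit_early_if_necessary` (the numbers mirror the test added below and are illustrative only): integer division rounds the resident group count down to whole batches, so not every group is emitted early.

```rust
// Worked example of `let n = self.group_values.len() / self.batch_size * self.batch_size;`
fn main() {
    let (group_count, batch_size) = (1124usize, 1024usize);
    let n = group_count / batch_size * batch_size; // 1124 / 1024 = 1, then * 1024 = 1024
    assert_eq!(n, 1024);              // emitted early as one full batch
    assert_eq!(group_count - n, 100); // still resident in the hash table
    println!("emit {n} early, keep {} resident", group_count - n);
}
```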
```diff
@@ -1190,16 +1219,18 @@ impl GroupedHashAggregateStream {
     /// skipped, forces stream to produce currently accumulated output.
     ///
     /// Notice: It should only be called in Partial aggregation
-    fn switch_to_skip_aggregation(&mut self) -> Result<()> {
+    ///
+    /// Returns `Some(ExecutionState)` if the state should be changed, None otherwise.
+    fn switch_to_skip_aggregation(&mut self) -> Result<Option<ExecutionState>> {
         if let Some(probe) = self.skip_aggregation_probe.as_mut() {
             if probe.should_skip() {
                 if let Some(batch) = self.emit(EmitTo::All, false)? {
-                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                    return Ok(Some(ExecutionState::ProducingOutput(batch)));
                 };
             }
         }

-        Ok(())
+        Ok(None)
     }

     /// Returns true if the aggregation probe indicates that aggregation
```
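For context on `probe.should_skip()`, here is a rough sketch of the decision the probe makes. The semantics are my inference from the two config keys the test sets (`skip_partial_aggregation_probe_rows_threshold` and `skip_partial_aggregation_probe_ratio_threshold`); `SkipProbe` is a hypothetical stand-in, not DataFusion's actual probe type.

```rust
// Hypothetical sketch of a skip-aggregation probe: once enough input rows have
// been observed, skip partial aggregation if grouping isn't reducing cardinality.
struct SkipProbe {
    rows_threshold: usize, // ...probe_rows_threshold config key
    ratio_threshold: f64,  // ...probe_ratio_threshold config key
    input_rows: usize,
    num_groups: usize,
}

impl SkipProbe {
    fn update(&mut self, input_rows: usize, num_groups: usize) {
        self.input_rows += input_rows;
        self.num_groups = num_groups;
    }

    fn should_skip(&self) -> bool {
        self.input_rows >= self.rows_threshold
            && (self.num_groups as f64 / self.input_rows as f64) >= self.ratio_threshold
    }
}

fn main() {
    // Mirrors the test setup: 50-row threshold, 0.8 ratio, one row per group.
    let mut probe = SkipProbe {
        rows_threshold: 50,
        ratio_threshold: 0.8,
        input_rows: 0,
        num_groups: 0,
    };
    probe.update(1124, 1124); // 100% cardinality: every row is its own group
    assert!(probe.should_skip());
    println!("should_skip = {}", probe.should_skip());
}
```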
```diff
@@ -1239,3 +1270,123 @@ impl GroupedHashAggregateStream {
         Ok(states_batch)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test::TestMemoryExec;
+    use arrow::array::{Int32Array, Int64Array};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+    use datafusion_execution::TaskContext;
+    use datafusion_functions_aggregate::count::count_udaf;
+    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+    use datafusion_physical_expr::expressions::col;
+    use std::sync::Arc;
+
+    #[tokio::test]
+    async fn test_double_emission_race_condition_bug() -> Result<()> {
+        // Fix for https://github.com/apache/datafusion/issues/18701
+        // This test specifically proves that we have fixed the double emission
+        // race condition where emit_early_if_necessary() and
+        // switch_to_skip_aggregation() both emit in the same loop iteration,
+        // causing data loss
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("group_col", DataType::Int32, false),
+            Field::new("value_col", DataType::Int64, false),
+        ]));
+
+        // Create data that will trigger BOTH conditions in the same iteration:
+        // 1. More groups than batch_size (triggers early emission when memory pressure hits)
+        // 2. High cardinality ratio (triggers skip aggregation)
+        let batch_size = 1024; // We'll set this in session config
+        let num_groups = batch_size + 100; // Slightly more than batch_size (1124 groups)
+
+        // Create exactly 1 row per group = 100% cardinality ratio
+        let group_ids: Vec<i32> = (0..num_groups as i32).collect();
+        let values: Vec<i64> = vec![1; num_groups];
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(Int32Array::from(group_ids)),
+                Arc::new(Int64Array::from(values)),
+            ],
+        )?;
+
+        let input_partitions = vec![vec![batch]];
+
+        // Create constrained memory to trigger early emission but not completely fail
+        let runtime = RuntimeEnvBuilder::default()
+            .with_memory_limit(1024, 1.0) // 1 KB - enough to start but will trigger pressure
+            .build_arc()?;
+
+        let mut task_ctx = TaskContext::default().with_runtime(runtime);
+
+        // Configure to trigger BOTH conditions:
+        // 1. Low probe threshold (triggers skip probe after few rows)
+        // 2. Low ratio threshold (triggers skip aggregation immediately)
+        // 3. Set batch_size to 1024 so our 1124 groups will trigger early emission
+        // This creates the race condition where both emit paths are triggered
+        let mut session_config = task_ctx.session_config().clone();
+        session_config = session_config.set(
+            "datafusion.execution.batch_size",
+            &datafusion_common::ScalarValue::UInt64(Some(1024)),
+        );
+        session_config = session_config.set(
+            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
+            &datafusion_common::ScalarValue::UInt64(Some(50)),
+        );
+        session_config = session_config.set(
+            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
+            &datafusion_common::ScalarValue::Float64(Some(0.8)),
+        );
+        task_ctx = task_ctx.with_session_config(session_config);
+        let task_ctx = Arc::new(task_ctx);
+
+        // Create aggregate: COUNT(*) GROUP BY group_col
+        let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())];
+        let aggr_expr = vec![Arc::new(
+            AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?])
+                .schema(Arc::clone(&schema))
+                .alias("count_value")
+                .build()?,
+        )];
+
+        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
+        let exec = Arc::new(TestMemoryExec::update_cache(Arc::new(exec)));
+
+        // Use Partial mode where the race condition occurs
+        let aggregate_exec = AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::new_single(group_expr),
+            aggr_expr,
+            vec![None],
+            exec,
+            Arc::clone(&schema),
+        )?;
+
+        // Execute and collect results
+        let mut stream =
+            GroupedHashAggregateStream::new(&aggregate_exec, Arc::clone(&task_ctx), 0)?;
+        let mut results = Vec::new();
+
+        while let Some(result) = stream.next().await {
+            let batch = result?;
+            results.push(batch);
+        }
+
+        // Count total groups emitted
+        let mut total_output_groups = 0;
+        for batch in &results {
+            total_output_groups += batch.num_rows();
+        }
+
+        assert_eq!(
+            total_output_groups, num_groups,
+            "Unexpected number of groups",
+        );
+
+        Ok(())
+    }
+}
```
**Review comment:** It's unclear to me why we set `update_skip_aggregation_probe` here (https://github.com/apache/datafusion/pull/18712/files#diff-69c8ecaca5e2c7005f2ed1facaa41f80b45bfd006f2357e53ff3072f535c287dL687) and not next to `switch_to_skip_aggregation`. I can't fully give an explanation yet, but allowing the probe to be updated and then allowing the loop to break before we get here seems dangerous? It's important that we emit everything before we move to the `SkipAggregation` state?

**Review comment:** @korowa does that make sense to you?