-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Optimize count agg expr with null column statistics #1063
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bf0be0a
ca4e1c8
472c80a
ed7c838
ad5311c
3e28a0f
597338b
fb683c6
6dd40ea
df032a2
c73d6bc
62e0eeb
a47e812
42c01d1
1c65b6a
5f8c9fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -57,7 +57,13 @@ impl PhysicalOptimizerRule for AggregateStatistics { | |
| let stats = partial_agg_exec.input().statistics(); | ||
| let mut projections = vec![]; | ||
| for expr in partial_agg_exec.aggr_expr() { | ||
| if let Some((num_rows, name)) = take_optimizable_count(&**expr, &stats) { | ||
| if let Some((non_null_rows, name)) = | ||
| take_optimizable_column_count(&**expr, &stats) | ||
| { | ||
| projections.push((expressions::lit(non_null_rows), name.to_owned())); | ||
| } else if let Some((num_rows, name)) = | ||
| take_optimizable_table_count(&**expr, &stats) | ||
| { | ||
| projections.push((expressions::lit(num_rows), name.to_owned())); | ||
| } else if let Some((min, name)) = take_optimizable_min(&**expr, &stats) { | ||
| projections.push((expressions::lit(min), name.to_owned())); | ||
|
|
@@ -127,7 +133,7 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>> | |
| } | ||
|
|
||
| /// If this agg_expr is a count that is defined in the statistics, return it | ||
| fn take_optimizable_count( | ||
| fn take_optimizable_table_count( | ||
| agg_expr: &dyn AggregateExpr, | ||
| stats: &Statistics, | ||
| ) -> Option<(ScalarValue, &'static str)> { | ||
|
|
@@ -144,7 +150,40 @@ fn take_optimizable_count( | |
| if lit_expr.value() == &ScalarValue::UInt8(Some(1)) { | ||
| return Some(( | ||
| ScalarValue::UInt64(Some(num_rows as u64)), | ||
| "COUNT(Uint8(1))", | ||
| "COUNT(UInt8(1))", | ||
| )); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| None | ||
| } | ||
|
|
||
| /// If this agg_expr is a count that can be derived from the statistics, return it | ||
| fn take_optimizable_column_count( | ||
| agg_expr: &dyn AggregateExpr, | ||
| stats: &Statistics, | ||
| ) -> Option<(ScalarValue, String)> { | ||
| if let (Some(num_rows), Some(col_stats), Some(casted_expr)) = ( | ||
| stats.num_rows, | ||
| &stats.column_statistics, | ||
| agg_expr.as_any().downcast_ref::<expressions::Count>(), | ||
| ) { | ||
| if casted_expr.expressions().len() == 1 { | ||
| // TODO optimize with exprs other than Column | ||
| if let Some(col_expr) = casted_expr.expressions()[0] | ||
| .as_any() | ||
| .downcast_ref::<expressions::Column>() | ||
| { | ||
| if let ColumnStatistics { | ||
| null_count: Some(val), | ||
| .. | ||
| } = &col_stats[col_expr.index()] | ||
| { | ||
| let expr = format!("COUNT({})", col_expr.name()); | ||
| return Some(( | ||
| ScalarValue::UInt64(Some((num_rows - val) as u64)), | ||
| expr, | ||
| )); | ||
| } | ||
| } | ||
|
|
@@ -237,8 +276,8 @@ mod tests { | |
| let batch = RecordBatch::try_new( | ||
| Arc::clone(&schema), | ||
| vec![ | ||
| Arc::new(Int32Array::from(vec![1, 2, 3])), | ||
| Arc::new(Int32Array::from(vec![4, 5, 6])), | ||
| Arc::new(Int32Array::from(vec![Some(1), Some(2), None])), | ||
| Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])), | ||
| ], | ||
| )?; | ||
|
|
||
|
|
@@ -250,38 +289,41 @@ mod tests { | |
| } | ||
|
|
||
| /// Checks that the count optimization was applied and we still get the right result | ||
| async fn assert_count_optim_success(plan: HashAggregateExec) -> Result<()> { | ||
| async fn assert_count_optim_success( | ||
| plan: HashAggregateExec, | ||
| nulls: bool, | ||
| ) -> Result<()> { | ||
| let conf = ExecutionConfig::new(); | ||
| let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?; | ||
|
|
||
| let (col, count) = match nulls { | ||
| false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false), 3), | ||
| true => (Field::new("COUNT(a)", DataType::UInt64, false), 2), | ||
| }; | ||
|
|
||
| // A ProjectionExec is a sign that the count optimization was applied | ||
| assert!(optimized.as_any().is::<ProjectionExec>()); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a comment here that the added
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure - added it. |
||
| let result = common::collect(optimized.execute(0).await?).await?; | ||
| assert_eq!( | ||
| result[0].schema(), | ||
| Arc::new(Schema::new(vec![Field::new( | ||
| "COUNT(Uint8(1))", | ||
| DataType::UInt64, | ||
| false | ||
| )])) | ||
| ); | ||
| assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col]))); | ||
| assert_eq!( | ||
| result[0] | ||
| .column(0) | ||
| .as_any() | ||
| .downcast_ref::<UInt64Array>() | ||
| .unwrap() | ||
| .values(), | ||
| &[3] | ||
| &[count] | ||
| ); | ||
| Ok(()) | ||
| } | ||
|
|
||
| fn count_expr() -> Arc<dyn AggregateExpr> { | ||
| Arc::new(Count::new( | ||
| expressions::lit(ScalarValue::UInt8(Some(1))), | ||
| "my_count_alias", | ||
| DataType::UInt64, | ||
| )) | ||
| fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc<dyn AggregateExpr> { | ||
| // Return appropriate expr depending if COUNT is for col or table | ||
| let expr = match schema { | ||
| None => expressions::lit(ScalarValue::UInt8(Some(1))), | ||
| Some(s) => expressions::col(col.unwrap(), s).unwrap(), | ||
| }; | ||
| Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64)) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
|
|
@@ -293,20 +335,47 @@ mod tests { | |
| let partial_agg = HashAggregateExec::try_new( | ||
| AggregateMode::Partial, | ||
| vec![], | ||
| vec![count_expr()], | ||
| vec![count_expr(None, None)], | ||
| source, | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| let final_agg = HashAggregateExec::try_new( | ||
| AggregateMode::Final, | ||
| vec![], | ||
| vec![count_expr()], | ||
| vec![count_expr(None, None)], | ||
| Arc::new(partial_agg), | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| assert_count_optim_success(final_agg).await?; | ||
| assert_count_optim_success(final_agg, false).await?; | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn test_count_partial_with_nulls_direct_child() -> Result<()> { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is not testing the code that you have added, it tests that
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, thx for picking that up. Looking into it. |
||
| // basic test case with the aggregation applied on a source with exact statistics | ||
| let source = mock_data()?; | ||
| let schema = source.schema(); | ||
|
|
||
| let partial_agg = HashAggregateExec::try_new( | ||
| AggregateMode::Partial, | ||
| vec![], | ||
| vec![count_expr(Some(&schema), Some("a"))], | ||
| source, | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| let final_agg = HashAggregateExec::try_new( | ||
| AggregateMode::Final, | ||
| vec![], | ||
| vec![count_expr(Some(&schema), Some("a"))], | ||
| Arc::new(partial_agg), | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| assert_count_optim_success(final_agg, true).await?; | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
@@ -319,7 +388,36 @@ mod tests { | |
| let partial_agg = HashAggregateExec::try_new( | ||
| AggregateMode::Partial, | ||
| vec![], | ||
| vec![count_expr()], | ||
| vec![count_expr(None, None)], | ||
| source, | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| // We introduce an intermediate optimization step between the partial and final aggregtator | ||
| let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); | ||
|
|
||
| let final_agg = HashAggregateExec::try_new( | ||
| AggregateMode::Final, | ||
| vec![], | ||
| vec![count_expr(None, None)], | ||
| Arc::new(coalesce), | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| assert_count_optim_success(final_agg, false).await?; | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { | ||
| let source = mock_data()?; | ||
| let schema = source.schema(); | ||
|
|
||
| let partial_agg = HashAggregateExec::try_new( | ||
| AggregateMode::Partial, | ||
| vec![], | ||
| vec![count_expr(Some(&schema), Some("a"))], | ||
| source, | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
@@ -330,12 +428,12 @@ mod tests { | |
| let final_agg = HashAggregateExec::try_new( | ||
| AggregateMode::Final, | ||
| vec![], | ||
| vec![count_expr()], | ||
| vec![count_expr(Some(&schema), Some("a"))], | ||
| Arc::new(coalesce), | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| assert_count_optim_success(final_agg).await?; | ||
| assert_count_optim_success(final_agg, true).await?; | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
@@ -359,15 +457,57 @@ mod tests { | |
| let partial_agg = HashAggregateExec::try_new( | ||
| AggregateMode::Partial, | ||
| vec![], | ||
| vec![count_expr()], | ||
| vec![count_expr(None, None)], | ||
| filter, | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| let final_agg = HashAggregateExec::try_new( | ||
| AggregateMode::Final, | ||
| vec![], | ||
| vec![count_expr(None, None)], | ||
| Arc::new(partial_agg), | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| let conf = ExecutionConfig::new(); | ||
| let optimized = | ||
| AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; | ||
|
|
||
| // check that the original ExecutionPlan was not replaced | ||
| assert!(optimized.as_any().is::<HashAggregateExec>()); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn test_count_with_nulls_inexact_stat() -> Result<()> { | ||
| let source = mock_data()?; | ||
| let schema = source.schema(); | ||
|
|
||
| // adding a filter makes the statistics inexact | ||
| let filter = Arc::new(FilterExec::try_new( | ||
| expressions::binary( | ||
| expressions::col("a", &schema)?, | ||
| Operator::Gt, | ||
| expressions::lit(ScalarValue::from(1u32)), | ||
| &schema, | ||
| )?, | ||
| source, | ||
| )?); | ||
|
|
||
| let partial_agg = HashAggregateExec::try_new( | ||
| AggregateMode::Partial, | ||
| vec![], | ||
| vec![count_expr(Some(&schema), Some("a"))], | ||
| filter, | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| let final_agg = HashAggregateExec::try_new( | ||
| AggregateMode::Final, | ||
| vec![], | ||
| vec![count_expr()], | ||
| vec![count_expr(Some(&schema), Some("a"))], | ||
| Arc::new(partial_agg), | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it looks like this code handles
count(col)whereas the code above only handlescount(*)-- that seems strange -- perhaps we should update it so both can handlecount(col)andcount(*)?Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My understanding is that
COUNT(*)doesnt need to have a separate handler for nulls - assuming we expect same behavior as psql. For example in psql when i do the following:Does it make sense to reframe these optimizations as the following:
take_optimizable_table_count(currenttake_optimizable_count)=> comes fromCOUNT(*)and returnsnum_rowstake_optimizable_column_count(currenttake_optimizable_count_with_nulls) => comes fromCOUNT(col)and returnnum_rows - null_countforcolThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think those names make more sense to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok - ive updated. Let me know if anything else needed.