diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 4abe3ce0edc4..396e66972f30 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -412,6 +412,7 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], + vec![None], source, Arc::clone(&schema), )?; @@ -421,6 +422,7 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], + vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -442,6 +444,7 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], + vec![None], source, Arc::clone(&schema), )?; @@ -451,6 +454,7 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], + vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -471,6 +475,7 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], + vec![None], source, Arc::clone(&schema), )?; @@ -483,6 +488,7 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], + vec![None], Arc::new(coalesce), Arc::clone(&schema), )?; @@ -503,6 +509,7 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], + vec![None], source, Arc::clone(&schema), )?; @@ -515,6 +522,7 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], + vec![None], Arc::new(coalesce), Arc::clone(&schema), )?; @@ -546,6 +554,7 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], + vec![None], filter, Arc::clone(&schema), )?; @@ -555,6 +564,7 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], + vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -591,6 +601,7 @@ mod tests { PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], + vec![None], filter, Arc::clone(&schema), )?; @@ -600,6 +611,7 @@ mod tests { 
PhysicalGroupBy::default(), vec![agg.count_expr()], vec![None], + vec![None], Arc::new(partial_agg), Arc::clone(&schema), )?; diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 59e5fc95edbb..5657c62921f9 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -73,6 +73,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { group_by: input_group_by, aggr_expr: input_aggr_expr, filter_expr: input_filter_expr, + order_by_expr: input_order_by_expr, input_schema, .. }| { @@ -95,6 +96,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { input_group_by.clone(), input_aggr_expr.to_vec(), input_filter_expr.to_vec(), + input_order_by_expr.to_vec(), partial_input.clone(), input_schema.clone(), ) @@ -279,6 +281,7 @@ mod tests { group_by, aggr_expr, vec![], + vec![], input, schema, ) @@ -298,6 +301,7 @@ mod tests { group_by, aggr_expr, vec![], + vec![], input, schema, ) diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs index e6fd15b7ce88..4c30170ace3d 100644 --- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs @@ -42,7 +42,7 @@ use datafusion_physical_expr::expressions::NoOp; use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ expr_list_eq_strict_order, normalize_expr_with_equivalence_properties, AggregateExpr, - PhysicalExpr, + PhysicalExpr, PhysicalSortExpr, }; use std::sync::Arc; @@ -254,6 +254,7 @@ fn adjust_input_keys_ordering( group_by, aggr_expr, filter_expr, + order_by_expr, input, input_schema, .. 
@@ -267,6 +268,7 @@ fn adjust_input_keys_ordering( group_by, aggr_expr, filter_expr, + order_by_expr, input.clone(), input_schema, )?), @@ -367,12 +369,14 @@ where } } +#[allow(clippy::too_many_arguments)] fn reorder_aggregate_keys( agg_plan: Arc, parent_required: &[Arc], group_by: &PhysicalGroupBy, aggr_expr: &[Arc], filter_expr: &[Option>], + order_by_expr: &[Option>], agg_input: Arc, input_schema: &SchemaRef, ) -> Result { @@ -403,6 +407,7 @@ fn reorder_aggregate_keys( group_by, aggr_expr, filter_expr, + order_by_expr, input, input_schema, .. @@ -422,6 +427,7 @@ fn reorder_aggregate_keys( new_partial_group_by, aggr_expr.clone(), filter_expr.clone(), + order_by_expr.clone(), input.clone(), input_schema.clone(), )?)) @@ -453,6 +459,7 @@ fn reorder_aggregate_keys( new_group_by, aggr_expr.to_vec(), filter_expr.to_vec(), + order_by_expr.to_vec(), partial_agg, input_schema.clone(), )?); @@ -1104,12 +1111,14 @@ mod tests { final_grouping, vec![], vec![], + vec![], Arc::new( AggregateExec::try_new( AggregateMode::Partial, group_by, vec![], vec![], + vec![], input, schema.clone(), ) diff --git a/datafusion/core/src/physical_optimizer/repartition.rs b/datafusion/core/src/physical_optimizer/repartition.rs index 1b52e4e27468..8b407ed2891b 100644 --- a/datafusion/core/src/physical_optimizer/repartition.rs +++ b/datafusion/core/src/physical_optimizer/repartition.rs @@ -478,12 +478,14 @@ mod tests { PhysicalGroupBy::default(), vec![], vec![], + vec![], Arc::new( AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), vec![], vec![], + vec![], input, schema.clone(), ) diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs index 3d9363f34d84..f71c79e9fc82 100644 --- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs @@ -2867,6 +2867,7 @@ mod tests { PhysicalGroupBy::default(), vec![], vec![], + vec![], input, 
schema, ) diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index ffd985513dbc..247dfc27784e 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -41,8 +41,9 @@ use datafusion_physical_expr::{ equivalence::project_equivalence_properties, expressions::{Avg, CastExpr, Column, Sum}, normalize_out_expr_with_columns_map, - utils::{convert_to_expr, get_indices_of_matching_exprs}, - AggregateExpr, PhysicalExpr, PhysicalSortExpr, + utils::{convert_to_expr, get_indices_of_matching_exprs, ordering_satisfy_concrete}, + AggregateExpr, OrderingEquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, }; use std::any::Any; use std::collections::HashMap; @@ -228,6 +229,8 @@ pub struct AggregateExec { pub(crate) aggr_expr: Vec>, /// FILTER (WHERE clause) expression for each aggregate expression pub(crate) filter_expr: Vec>>, + /// (ORDER BY clause) expression for each aggregate expression + pub(crate) order_by_expr: Vec>>, /// Input plan, could be a partial aggregate or the input to the aggregate pub(crate) input: Arc, /// Schema after the aggregate is applied @@ -243,6 +246,7 @@ pub struct AggregateExec { metrics: ExecutionPlanMetricsSet, /// Stores mode and output ordering information for the `AggregateExec`. aggregation_ordering: Option, + required_input_ordering: Option>, } /// Calculates the working mode for `GROUP BY` queries. 
@@ -337,6 +341,54 @@ fn output_group_expr_helper(group_by: &PhysicalGroupBy) -> Vec EquivalenceProperties, + F2: Fn() -> OrderingEquivalenceProperties, +>( + order_by_expr: &[Option>], + eq_properties: F, + ordering_eq_properties: F2, +) -> Result>> { + let mut result: Option> = None; + for fn_reqs in order_by_expr.iter().flatten() { + if let Some(result) = &mut result { + if ordering_satisfy_concrete( + result, + fn_reqs, + &eq_properties, + &ordering_eq_properties, + ) { + // Do not update the result as it already satisfies current + // function's requirement: + continue; + } + if ordering_satisfy_concrete( + fn_reqs, + result, + &eq_properties, + &ordering_eq_properties, + ) { + // Update result with current function's requirements, as it is + // a finer requirement than what we currently have. + *result = fn_reqs.clone(); + continue; + } + // If neither of the requirements satisfy the other, this means + // requirements are conflicting. Currently, we do not support + // conflicting requirements. + return Err(DataFusionError::NotImplemented( + "Conflicting ordering requirements in aggregate functions is not supported".to_string(), + )); + } else { + result = Some(fn_reqs.clone()); + } + } + Ok(result) +} + impl AggregateExec { /// Create a new hash aggregate execution plan pub fn try_new( @@ -344,6 +396,7 @@ impl AggregateExec { group_by: PhysicalGroupBy, aggr_expr: Vec>, filter_expr: Vec>>, + order_by_expr: Vec>>, input: Arc, input_schema: SchemaRef, ) -> Result { @@ -356,6 +409,19 @@ impl AggregateExec { )?; let schema = Arc::new(schema); + let mut aggregator_requirement = None; + // Ordering requirement makes sense only in Partial and Single modes. + // In other modes, all groups are collapsed, therefore their input schema + // can not contain expressions in the requirement. 
+ if mode == AggregateMode::Partial || mode == AggregateMode::Single { + let requirement = get_finest_requirement( + &order_by_expr, + || input.equivalence_properties(), + || input.ordering_equivalence_properties(), + )?; + aggregator_requirement = requirement + .map(|exprs| PhysicalSortRequirement::from_sort_exprs(exprs.iter())); + } // construct a map from the input columns to the output columns of the Aggregation let mut columns_map: HashMap> = HashMap::new(); @@ -369,17 +435,52 @@ impl AggregateExec { let aggregation_ordering = calc_aggregation_ordering(&input, &group_by); + let mut required_input_ordering = None; + if let Some(AggregationOrdering { + ordering, + // If the mode is FullyOrdered or PartiallyOrdered (i.e. we are + // running with bounded memory, without breaking pipeline), then + // we append aggregator ordering requirement to the existing + // ordering. This way, we can still run with bounded memory. + mode: GroupByOrderMode::FullyOrdered | GroupByOrderMode::PartiallyOrdered, + .. 
+ }) = &aggregation_ordering + { + if let Some(aggregator_requirement) = aggregator_requirement { + // Get the section of the input ordering that enables us to run in the + // FullyOrdered or PartiallyOrdered mode: + let requirement_prefix = + if let Some(existing_ordering) = input.output_ordering() { + existing_ordering[0..ordering.len()].to_vec() + } else { + vec![] + }; + let mut requirement = + PhysicalSortRequirement::from_sort_exprs(requirement_prefix.iter()); + for req in aggregator_requirement { + if requirement.iter().all(|item| req.expr.ne(&item.expr)) { + requirement.push(req); + } + } + required_input_ordering = Some(requirement); + } + } else { + required_input_ordering = aggregator_requirement; + } + Ok(AggregateExec { mode, group_by, aggr_expr, filter_expr, + order_by_expr, input, schema, input_schema, columns_map, metrics: ExecutionPlanMetricsSet::new(), aggregation_ordering, + required_input_ordering, }) } @@ -408,6 +509,11 @@ impl AggregateExec { &self.filter_expr } + /// ORDER BY clause expression for each aggregate expression + pub fn order_by_expr(&self) -> &[Option>] { + &self.order_by_expr + } + /// Input plan pub fn input(&self) -> &Arc { &self.input @@ -547,6 +653,10 @@ impl ExecutionPlan for AggregateExec { } } + fn required_input_ordering(&self) -> Vec>> { + vec![self.required_input_ordering.clone()] + } + fn equivalence_properties(&self) -> EquivalenceProperties { let mut new_properties = EquivalenceProperties::new(self.schema()); project_equivalence_properties( @@ -570,6 +680,7 @@ impl ExecutionPlan for AggregateExec { self.group_by.clone(), self.aggr_expr.clone(), self.filter_expr.clone(), + self.order_by_expr.clone(), children[0].clone(), self.input_schema.clone(), )?)) @@ -951,7 +1062,8 @@ mod tests { use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use crate::from_slice::FromSlice; use crate::physical_plan::aggregates::{ - get_working_mode, AggregateExec, AggregateMode, PhysicalGroupBy, + get_finest_requirement, 
get_working_mode, AggregateExec, AggregateMode, + PhysicalGroupBy, }; use crate::physical_plan::expressions::{col, Avg}; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; @@ -962,8 +1074,13 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, Result, ScalarValue}; - use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median}; - use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; + use datafusion_physical_expr::expressions::{ + lit, ApproxDistinct, Column, Count, Median, + }; + use datafusion_physical_expr::{ + AggregateExpr, EquivalenceProperties, OrderedColumn, + OrderingEquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + }; use futures::{FutureExt, Stream}; use std::any::Any; use std::sync::Arc; @@ -1130,6 +1247,7 @@ mod tests { grouping_set.clone(), aggregates.clone(), vec![None], + vec![None], input, input_schema.clone(), )?); @@ -1173,6 +1291,7 @@ mod tests { final_grouping_set, aggregates, vec![None], + vec![None], merge, input_schema, )?); @@ -1235,6 +1354,7 @@ mod tests { grouping_set.clone(), aggregates.clone(), vec![None], + vec![None], input, input_schema.clone(), )?); @@ -1268,6 +1388,7 @@ mod tests { final_grouping_set, aggregates, vec![None], + vec![None], merge, input_schema, )?); @@ -1482,6 +1603,7 @@ mod tests { groups, aggregates, vec![None; 3], + vec![None; 3], input.clone(), input_schema.clone(), )?); @@ -1538,6 +1660,7 @@ mod tests { groups.clone(), aggregates.clone(), vec![None], + vec![None], blocking_exec, schema, )?); @@ -1577,6 +1700,7 @@ mod tests { groups, aggregates.clone(), vec![None], + vec![None], blocking_exec, schema, )?); @@ -1590,4 +1714,63 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_get_finest_requirements() -> Result<()> { + let test_schema = create_test_schema()?; + // Assume column a and b are aliases + // Assume also that a ASC and 
c DESC describe the same global ordering for the table. (Since they are ordering equivalent). + let options1 = SortOptions { + descending: false, + nulls_first: false, + }; + let options2 = SortOptions { + descending: true, + nulls_first: true, + }; + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let col_a = Column::new("a", 0); + let col_b = Column::new("b", 1); + let col_c = Column::new("c", 2); + let col_d = Column::new("d", 3); + eq_properties.add_equal_conditions((&col_a, &col_b)); + let mut ordering_eq_properties = OrderingEquivalenceProperties::new(test_schema); + ordering_eq_properties.add_equal_conditions(( + &OrderedColumn::new(col_a.clone(), options1), + &OrderedColumn::new(col_c.clone(), options2), + )); + + let order_by_exprs = vec![ + None, + Some(vec![PhysicalSortExpr { + expr: Arc::new(col_a.clone()), + options: options1, + }]), + Some(vec![PhysicalSortExpr { + expr: Arc::new(col_b.clone()), + options: options1, + }]), + Some(vec![PhysicalSortExpr { + expr: Arc::new(col_c), + options: options2, + }]), + Some(vec![ + PhysicalSortExpr { + expr: Arc::new(col_a), + options: options1, + }, + PhysicalSortExpr { + expr: Arc::new(col_d), + options: options1, + }, + ]), + ]; + let res = get_finest_requirement( + &order_by_exprs, + || eq_properties.clone(), + || ordering_eq_properties.clone(), + )?; + assert_eq!(res, order_by_exprs[4]); + Ok(()) + } } diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index cc52739f6acb..57c0251d243c 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -74,7 +74,7 @@ use datafusion_physical_expr::expressions::Literal; use datafusion_sql::utils::window_expr_common_partition_keys; use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; -use itertools::Itertools; +use itertools::{multiunzip, Itertools}; use log::{debug, trace}; use std::collections::HashMap; use 
std::fmt::Write; @@ -199,12 +199,23 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { args, .. }) => create_function_physical_name(&fun.to_string(), *distinct, args), - Expr::AggregateUDF(AggregateUDF { fun, args, filter }) => { + Expr::AggregateUDF(AggregateUDF { + fun, + args, + filter, + order_by, + }) => { + // TODO: Add support for filter and order by in AggregateUDF if filter.is_some() { return Err(DataFusionError::Execution( "aggregate expression with filter is not supported".to_string(), )); } + if order_by.is_some() { + return Err(DataFusionError::Execution( + "aggregate expression with order_by is not supported".to_string(), + )); + } let mut names = Vec::with_capacity(args.len()); for e in args { names.push(create_physical_name(e, false)?); @@ -703,13 +714,15 @@ impl DefaultPhysicalPlanner { ) }) .collect::>>()?; - let (aggregates, filters): (Vec<_>, Vec<_>) = agg_filter.into_iter().unzip(); + + let (aggregates, filters, order_bys) : (Vec<_>, Vec<_>, Vec<_>) = multiunzip(agg_filter.into_iter()); let initial_aggr = Arc::new(AggregateExec::try_new( AggregateMode::Partial, groups.clone(), aggregates.clone(), filters.clone(), + order_bys.clone(), input_exec, physical_input_schema.clone(), )?); @@ -746,6 +759,7 @@ impl DefaultPhysicalPlanner { final_grouping_set, aggregates, filters, + order_bys, initial_aggr, physical_input_schema.clone(), )?)) @@ -867,25 +881,12 @@ impl DefaultPhysicalPlanner { let input_dfschema = input.as_ref().schema(); let sort_expr = expr .iter() - .map(|e| match e { - Expr::Sort(expr::Sort { - expr, - asc, - nulls_first, - }) => create_physical_sort_expr( - expr, - input_dfschema, - &input_schema, - SortOptions { - descending: !*asc, - nulls_first: *nulls_first, - }, - session_state.execution_props(), - ), - _ => Err(DataFusionError::Plan( - "Sort only accepts sort expressions".to_string(), - )), - }) + .map(|e| create_physical_sort_expr( + e, + input_dfschema, + &input_schema, + session_state.execution_props(), + 
)) .collect::>>()?; let new_sort = SortExec::new(sort_expr, physical_input) .with_fetch(*fetch); @@ -1554,24 +1555,13 @@ pub fn create_window_expr_with_name( .collect::>>()?; let order_by = order_by .iter() - .map(|e| match e { - Expr::Sort(expr::Sort { - expr, - asc, - nulls_first, - }) => create_physical_sort_expr( - expr, + .map(|e| { + create_physical_sort_expr( + e, logical_input_schema, physical_input_schema, - SortOptions { - descending: !*asc, - nulls_first: *nulls_first, - }, execution_props, - ), - _ => Err(DataFusionError::Plan( - "Sort only accepts sort expressions".to_string(), - )), + ) }) .collect::>>()?; if !is_window_valid(window_frame) { @@ -1619,8 +1609,13 @@ pub fn create_window_expr( ) } -type AggregateExprWithOptionalFilter = - (Arc, Option>); +type AggregateExprWithOptionalArgs = ( + Arc, + // The filter clause, if any + Option>, + // Ordering requirements, if any + Option>, +); /// Create an aggregate expression with a name from a logical expression pub fn create_aggregate_expr_with_name_and_maybe_filter( @@ -1629,13 +1624,14 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( logical_input_schema: &DFSchema, physical_input_schema: &Schema, execution_props: &ExecutionProps, -) -> Result { +) -> Result { match e { Expr::AggregateFunction(AggregateFunction { fun, distinct, args, filter, + order_by, }) => { let args = args .iter() @@ -1663,10 +1659,30 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( &args, physical_input_schema, name, - ); - Ok((agg_expr?, filter)) + )?; + let order_by = match order_by { + Some(e) => Some( + e.iter() + .map(|expr| { + create_physical_sort_expr( + expr, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?, + ), + None => None, + }; + Ok((agg_expr, filter, order_by)) } - Expr::AggregateUDF(AggregateUDF { fun, args, filter }) => { + Expr::AggregateUDF(AggregateUDF { + fun, + args, + filter, + order_by, + }) => { let args = args .iter() .map(|e| { @@ 
-1688,10 +1704,25 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( )?), None => None, }; + let order_by = match order_by { + Some(e) => Some( + e.iter() + .map(|expr| { + create_physical_sort_expr( + expr, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?, + ), + None => None, + }; let agg_expr = udaf::create_aggregate_expr(fun, &args, physical_input_schema, name); - Ok((agg_expr?, filter)) + Ok((agg_expr?, filter, order_by)) } other => Err(DataFusionError::Internal(format!( "Invalid aggregate expression '{other:?}'" @@ -1705,7 +1736,7 @@ pub fn create_aggregate_expr_and_maybe_filter( logical_input_schema: &DFSchema, physical_input_schema: &Schema, execution_props: &ExecutionProps, -) -> Result { +) -> Result { // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" let (name, e) = match e { Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()), @@ -1726,13 +1757,31 @@ pub fn create_physical_sort_expr( e: &Expr, input_dfschema: &DFSchema, input_schema: &Schema, - options: SortOptions, execution_props: &ExecutionProps, ) -> Result { - Ok(PhysicalSortExpr { - expr: create_physical_expr(e, input_dfschema, input_schema, execution_props)?, - options, - }) + if let Expr::Sort(expr::Sort { + expr, + asc, + nulls_first, + }) = e + { + Ok(PhysicalSortExpr { + expr: create_physical_expr( + expr, + input_dfschema, + input_schema, + execution_props, + )?, + options: SortOptions { + descending: !asc, + nulls_first: *nulls_first, + }, + }) + } else { + Err(DataFusionError::Internal( + "Expects a sort expression".to_string(), + )) + } } impl DefaultPhysicalPlanner { diff --git a/datafusion/core/tests/aggregate_fuzz.rs b/datafusion/core/tests/aggregate_fuzz.rs index 14cf46962453..74370049e81f 100644 --- a/datafusion/core/tests/aggregate_fuzz.rs +++ b/datafusion/core/tests/aggregate_fuzz.rs @@ -113,6 +113,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str group_by.clone(), 
aggregate_expr.clone(), vec![None], + vec![None], running_source, schema.clone(), ) @@ -125,6 +126,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str group_by.clone(), aggregate_expr.clone(), vec![None], + vec![None], usual_source, schema.clone(), ) diff --git a/datafusion/core/tests/sqllogictests/test_files/explain.slt b/datafusion/core/tests/sqllogictests/test_files/explain.slt index c64eaf62d3bf..7656dcea635f 100644 --- a/datafusion/core/tests/sqllogictests/test_files/explain.slt +++ b/datafusion/core/tests/sqllogictests/test_files/explain.slt @@ -147,8 +147,7 @@ analyzed_logical_plan SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE -logical_plan after decorrelate_where_exists SAME TEXT AS ABOVE -logical_plan after decorrelate_where_in SAME TEXT AS ABOVE +logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE @@ -176,8 +175,7 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE -logical_plan after decorrelate_where_exists SAME TEXT AS ABOVE -logical_plan after decorrelate_where_in SAME TEXT AS ABOVE +logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE diff --git a/datafusion/core/tests/sqllogictests/test_files/groupby.slt b/datafusion/core/tests/sqllogictests/test_files/groupby.slt index 
b9d2543e11b0..e5a93e709fa2 100644 --- a/datafusion/core/tests/sqllogictests/test_files/groupby.slt +++ b/datafusion/core/tests/sqllogictests/test_files/groupby.slt @@ -2010,3 +2010,208 @@ SELECT a, d, statement ok drop table annotated_data_infinite2; + +# create a table for testing +statement ok +CREATE TABLE sales_global (zip_code INT, + country VARCHAR(3), + sn INT, + ts TIMESTAMP, + currency VARCHAR(3), + amount FLOAT + ) as VALUES + (0, 'GRC', 0, '2022-01-01 06:00:00'::timestamp, 'EUR', 30.0), + (1, 'FRA', 1, '2022-01-01 08:00:00'::timestamp, 'EUR', 50.0), + (1, 'TUR', 2, '2022-01-01 11:30:00'::timestamp, 'TRY', 75.0), + (1, 'FRA', 3, '2022-01-02 12:00:00'::timestamp, 'EUR', 200.0), + (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0), + (0, 'GRC', 4, '2022-01-03 10:00:00'::timestamp, 'EUR', 80.0) + +# test_ordering_sensitive_aggregation +# ordering sensitive requirement should add a SortExec in the final plan. To satisfy amount ASC +# in the aggregation +statement ok +set datafusion.execution.target_partitions = 1; + +query TT +EXPLAIN SELECT country, (ARRAY_AGG(amount ORDER BY amount ASC)) AS amounts + FROM sales_global + GROUP BY country +---- +logical_plan +Projection: sales_global.country, ARRAYAGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts +--Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAYAGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +----TableScan: sales_global projection=[country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, ARRAYAGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAYAGG(sales_global.amount)] +----SortExec: expr=[amount@1 ASC NULLS LAST] +------MemoryExec: partitions=1, partition_sizes=[1] + + +query T? 
+SELECT country, (ARRAY_AGG(amount ORDER BY amount ASC)) AS amounts + FROM sales_global + GROUP BY country +---- +GRC [30.0, 80.0] +FRA [50.0, 200.0] +TUR [75.0, 100.0] + +# test_ordering_sensitive_aggregation2 +# We should be able to satisfy the finest requirement among all aggregators, when we have multiple aggregators. +# Hence final plan should have SortExec: expr=[amount@1 DESC] to satisfy array_agg requirement. +query TT +EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM sales_global AS s + GROUP BY s.country +---- +logical_plan +Projection: s.country, ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country]], aggr=[[ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)]] +----SubqueryAlias: s +------TableScan: sales_global projection=[country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAYAGG(s.amount), SUM(s.amount)] +----SortExec: expr=[amount@1 DESC] +------MemoryExec: partitions=1, partition_sizes=[1] + +query T?R +SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM sales_global AS s + GROUP BY s.country +---- +FRA [200.0, 50.0] 250 +TUR [100.0, 75.0] 175 +GRC [80.0, 30.0] 110 + +# test_ordering_sensitive_multiple_req +# Currently we do not support multiple ordering requirements for aggregation. +# Once this support is added, 
this test should change. +# See issue: https://github.com/sqlparser-rs/sqlparser-rs/issues/875 +statement error DataFusion error: SQL error: ParserError\("Expected \), found: ,"\) +SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC, s.country DESC) AS amounts, + SUM(s.amount ORDER BY s.amount DESC) AS sum1 + FROM sales_global AS s + GROUP BY s.country + +# test_ordering_sensitive_aggregation3 +# When different aggregators have conflicting requirements, we cannot satisfy all of them in current implementation. +# The test below should raise a NotImplemented error. +statement error DataFusion error: This feature is not implemented: Conflicting ordering requirements in aggregate functions is not supported +SELECT ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + ARRAY_AGG(s.amount ORDER BY s.amount ASC) AS amounts2, + ARRAY_AGG(s.amount ORDER BY s.sn ASC) AS amounts3 + FROM sales_global AS s + GROUP BY s.country + +# test_ordering_sensitive_aggregation4 +# If aggregators can work with bounded memory (FullyOrdered or PartiallyOrdered mode), we should append requirement to +# the existing ordering. This enables us to still work with bounded memory, and also satisfy aggregation requirement. +# This test checks for whether we can satisfy aggregation requirement in FullyOrdered mode. 
+query TT +EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country +---- +logical_plan +Projection: s.country, ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country]], aggr=[[ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)]] +----SubqueryAlias: s +------Sort: sales_global.country ASC NULLS LAST +--------TableScan: sales_global projection=[country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAYAGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered +----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] +------MemoryExec: partitions=1, partition_sizes=[1] + +query T?R +SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country +---- +FRA [200.0, 50.0] 250 +GRC [80.0, 30.0] 110 +TUR [100.0, 75.0] 175 + +# test_ordering_sensitive_aggregation5 +# If aggregators can work with bounded memory (FullyOrdered or PartiallyOrdered mode), we should append the requirement to +# the existing ordering. This enables us to still work with bounded memory, and also satisfy aggregation requirement. +# This test checks for whether we can satisfy aggregation requirement in PartiallyOrdered mode. 
+query TT +EXPLAIN SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country, s.zip_code +---- +logical_plan +Projection: s.country, s.zip_code, ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country, s.zip_code]], aggr=[[ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)]] +----SubqueryAlias: s +------Sort: sales_global.country ASC NULLS LAST +--------TableScan: sales_global projection=[zip_code, country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, zip_code@1 as zip_code, ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@2 as amounts, SUM(s.amount)@3 as sum1] +--AggregateExec: mode=Single, gby=[country@1 as country, zip_code@0 as zip_code], aggr=[ARRAYAGG(s.amount), SUM(s.amount)], ordering_mode=PartiallyOrdered +----SortExec: expr=[country@1 ASC NULLS LAST,amount@2 DESC] +------MemoryExec: partitions=1, partition_sizes=[1] + +query TI?R +SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country, s.zip_code +---- +FRA 1 [200.0, 50.0] 250 +GRC 0 [80.0, 30.0] 110 +TUR 1 [100.0, 75.0] 175 + +# test_ordering_sensitive_aggregation6 +# If aggregators can work with bounded memory (FullyOrdered or PartiallyOrdered mode), we should append the requirement to +# the existing ordering. When group by expressions contain aggregation requirement, we shouldn't append redundant expression. 
+# Hence in the final plan SortExec should be SortExec: expr=[country@0 ASC NULLS LAST] not SortExec: expr=[country@0 ASC NULLS LAST,country@0 DESC] +query TT +EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country +---- +logical_plan +Projection: s.country, ARRAYAGG(s.amount) ORDER BY [s.country DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country]], aggr=[[ARRAYAGG(s.amount) ORDER BY [s.country DESC NULLS FIRST], SUM(s.amount)]] +----SubqueryAlias: s +------Sort: sales_global.country ASC NULLS LAST +--------TableScan: sales_global projection=[country, amount] +physical_plan +ProjectionExec: expr=[country@0 as country, ARRAYAGG(s.amount) ORDER BY [s.country DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAYAGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered +----SortExec: expr=[country@0 ASC NULLS LAST] +------MemoryExec: partitions=1, partition_sizes=[1] + +query T?R +SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, + SUM(s.amount) AS sum1 + FROM (SELECT * + FROM sales_global + ORDER BY country) AS s + GROUP BY s.country +---- +FRA [200.0, 50.0] 250 +GRC [80.0, 30.0] 110 +TUR [100.0, 75.0] 175 diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 55bc71f1ab61..230e2fb916fa 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -27,9 +27,7 @@ use crate::window_frame; use crate::window_function; use crate::Operator; use arrow::datatypes::DataType; -use datafusion_common::Result; -use datafusion_common::{plan_err, Column}; -use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use std::collections::HashSet; use std::fmt; use std::fmt::{Display, Formatter, Write}; @@ -424,6 +422,8 
@@ pub struct AggregateFunction { pub distinct: bool, /// Optional filter pub filter: Option>, + /// Optional ordering + pub order_by: Option>, } impl AggregateFunction { @@ -432,12 +432,14 @@ impl AggregateFunction { args: Vec, distinct: bool, filter: Option>, + order_by: Option>, ) -> Self { Self { fun, args, distinct, filter, + order_by, } } } @@ -500,6 +502,8 @@ pub struct AggregateUDF { pub args: Vec, /// Optional filter pub filter: Option>, + /// Optional ORDER BY applied prior to aggregating + pub order_by: Option>, } impl AggregateUDF { @@ -508,8 +512,14 @@ impl AggregateUDF { fun: Arc, args: Vec, filter: Option>, + order_by: Option>, ) -> Self { - Self { fun, args, filter } + Self { + fun, + args, + filter, + order_by, + } } } @@ -1042,24 +1052,32 @@ impl fmt::Debug for Expr { distinct, ref args, filter, + order_by, .. }) => { fmt_function(f, &fun.to_string(), *distinct, args, true)?; if let Some(fe) = filter { write!(f, " FILTER (WHERE {fe})")?; } + if let Some(ob) = order_by { + write!(f, " ORDER BY {:?}", ob)?; + } Ok(()) } Expr::AggregateUDF(AggregateUDF { fun, ref args, filter, + order_by, .. 
}) => { fmt_function(f, &fun.name, false, args, false)?; if let Some(fe) = filter { write!(f, " FILTER (WHERE {fe})")?; } + if let Some(ob) = order_by { + write!(f, " ORDER BY {:?}", ob)?; + } Ok(()) } Expr::Between(Between { @@ -1398,25 +1416,35 @@ fn create_name(e: &Expr) -> Result { distinct, args, filter, + order_by, }) => { - let name = create_function_name(&fun.to_string(), *distinct, args)?; + let mut name = create_function_name(&fun.to_string(), *distinct, args)?; if let Some(fe) = filter { - Ok(format!("{name} FILTER (WHERE {fe})")) - } else { - Ok(name) - } + name = format!("{name} FILTER (WHERE {fe})"); + }; + if let Some(order_by) = order_by { + name = format!("{name} ORDER BY {order_by:?}"); + }; + Ok(name) } - Expr::AggregateUDF(AggregateUDF { fun, args, filter }) => { + Expr::AggregateUDF(AggregateUDF { + fun, + args, + filter, + order_by, + }) => { let mut names = Vec::with_capacity(args.len()); for e in args { names.push(create_name(e)?); } - let filter = if let Some(fe) = filter { - format!(" FILTER (WHERE {fe})") - } else { - "".to_string() - }; - Ok(format!("{}({}){}", fun.name, names.join(","), filter)) + let mut info = String::new(); + if let Some(fe) = filter { + info += &format!(" FILTER (WHERE {fe})"); + } + if let Some(ob) = order_by { + info += &format!(" ORDER BY ({:?})", ob); + } + Ok(format!("{}({}){}", fun.name, names.join(","), info)) } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1d781e6f0bf3..6aa51be566bb 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -108,6 +108,7 @@ pub fn min(expr: Expr) -> Expr { vec![expr], false, None, + None, )) } @@ -118,6 +119,7 @@ pub fn max(expr: Expr) -> Expr { vec![expr], false, None, + None, )) } @@ -128,6 +130,7 @@ pub fn sum(expr: Expr) -> Expr { vec![expr], false, None, + None, )) } @@ -138,6 +141,7 @@ pub fn avg(expr: Expr) -> Expr { 
vec![expr], false, None, + None, )) } @@ -148,6 +152,7 @@ pub fn count(expr: Expr) -> Expr { vec![expr], false, None, + None, )) } @@ -203,6 +208,7 @@ pub fn count_distinct(expr: Expr) -> Expr { vec![expr], true, None, + None, )) } @@ -254,6 +260,7 @@ pub fn approx_distinct(expr: Expr) -> Expr { vec![expr], false, None, + None, )) } @@ -264,6 +271,7 @@ pub fn median(expr: Expr) -> Expr { vec![expr], false, None, + None, )) } @@ -274,6 +282,7 @@ pub fn approx_median(expr: Expr) -> Expr { vec![expr], false, None, + None, )) } @@ -284,6 +293,7 @@ pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr { vec![expr, percentile], false, None, + None, )) } @@ -298,6 +308,7 @@ pub fn approx_percentile_cont_with_weight( vec![expr, weight_expr, percentile], false, None, + None, )) } @@ -367,6 +378,7 @@ pub fn stddev(expr: Expr) -> Expr { vec![expr], false, None, + None, )) } diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 3b8df59dab7b..3c9884410912 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -97,13 +97,16 @@ impl TreeNode for Expr { } expr_vec } - Expr::AggregateFunction(AggregateFunction { args, filter, .. }) - | Expr::AggregateUDF(AggregateUDF { args, filter, .. }) => { + Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) + | Expr::AggregateUDF(AggregateUDF { args, filter, order_by, .. 
}) => { let mut expr_vec = args.clone(); if let Some(f) = filter { expr_vec.push(f.as_ref().clone()); } + if let Some(o) = order_by { + expr_vec.extend(o.clone()); + } expr_vec } @@ -292,11 +295,13 @@ impl TreeNode for Expr { fun, distinct, filter, + order_by, }) => Expr::AggregateFunction(AggregateFunction::new( fun, transform_vec(args, &mut transform)?, distinct, transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, )), Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( @@ -314,11 +319,22 @@ impl TreeNode for Expr { )) } }, - Expr::AggregateUDF(AggregateUDF { args, fun, filter }) => { + Expr::AggregateUDF(AggregateUDF { + args, + fun, + filter, + order_by, + }) => { + let order_by = if let Some(order_by) = order_by { + Some(transform_vec(order_by, &mut transform)?) + } else { + None + }; Expr::AggregateUDF(AggregateUDF::new( fun, transform_vec(args, &mut transform)?, transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, )) } Expr::InList(InList { @@ -371,6 +387,21 @@ where .transpose() } +/// &mut transform a Option<`Vec` of `Expr`s> +fn transform_option_vec( + option_box: Option>, + transform: &mut F, +) -> Result>> +where + F: FnMut(Expr) -> Result, +{ + Ok(if let Some(exprs) = option_box { + Some(transform_vec(exprs, transform)?) 
+ } else { + None + }) +} + /// &mut transform a `Vec` of `Expr`s fn transform_vec(v: Vec, transform: &mut F) -> Result> where diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index d681390d27cc..6c3690e283d2 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -90,6 +90,7 @@ impl AggregateUDF { fun: Arc::new(self.clone()), args, filter: None, + order_by: None, }) } } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 3b0e334618a9..0e2689c14fa1 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -168,12 +168,14 @@ impl TreeNodeRewriter for CountWildcardRewriter { args, distinct, filter, + order_by, }) if args.len() == 1 => match args[0] { Expr::Wildcard => Expr::AggregateFunction(AggregateFunction { fun: aggregate_function::AggregateFunction::Count, args: vec![lit(COUNT_STAR_EXPANSION)], distinct, filter, + order_by, }), _ => old_expr, }, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 6de095cccdd8..fbb61fb1a31a 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -393,6 +393,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { args, distinct, filter, + order_by, }) => { let new_expr = coerce_agg_exprs_for_signature( &fun, @@ -401,18 +402,24 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &aggregate_function::signature(&fun), )?; let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, + fun, new_expr, distinct, filter, order_by, )); Ok(expr) } - Expr::AggregateUDF(expr::AggregateUDF { fun, args, filter }) => { + Expr::AggregateUDF(expr::AggregateUDF { + fun, + args, + filter, + order_by, + }) => { let new_expr = coerce_arguments_for_signature( args.as_slice(), 
&self.schema, &fun.signature, )?; - let expr = - Expr::AggregateUDF(expr::AggregateUDF::new(fun, new_expr, filter)); + let expr = Expr::AggregateUDF(expr::AggregateUDF::new( + fun, new_expr, filter, order_by, + )); Ok(expr) } Expr::WindowFunction(WindowFunction { @@ -885,6 +892,7 @@ mod test { Arc::new(my_avg), vec![lit(10i64)], None, + None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation"; @@ -915,6 +923,7 @@ mod test { Arc::new(my_avg), vec![lit("10")], None, + None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "") @@ -936,6 +945,7 @@ mod test { vec![lit(12i64)], false, None, + None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); let expected = "Projection: AVG(Int64(12))\n EmptyRelation"; @@ -948,6 +958,7 @@ mod test { vec![col("a")], false, None, + None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); let expected = "Projection: AVG(a)\n EmptyRelation"; @@ -964,6 +975,7 @@ mod test { vec![lit("1")], false, None, + None, )); let err = Projection::try_new(vec![agg_expr], empty).err().unwrap(); assert_eq!( diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index a2db93330918..6989ca535240 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -880,6 +880,7 @@ mod test { )), vec![inner], None, + None, )) }; diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 5c6f10825a86..c64dfc578b96 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -1056,6 +1056,7 @@ mod tests { vec![col("b")], false, 
Some(Box::new(col("c").gt(lit(42)))), + None, )); let plan = LogicalPlanBuilder::from(table_scan) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index cee31b5b3352..ba7e89094b0f 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -131,6 +131,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { fun, args, filter, + order_by, .. }) => { // is_single_distinct_agg ensure args.len=1 @@ -144,6 +145,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { vec![col(SINGLE_DISTINCT_ALIAS)], false, // intentional to remove distinct here filter.clone(), + order_by.clone(), ))) } _ => Ok(aggr_expr.clone()), @@ -399,6 +401,7 @@ mod tests { vec![col("b")], true, None, + None, )), ], )? diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 967ecdb40f06..a8a0625ca019 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -304,7 +304,7 @@ pub fn ordering_satisfy< /// Checks whether the required [`PhysicalSortExpr`]s are satisfied by the /// provided [`PhysicalSortExpr`]s. 
-fn ordering_satisfy_concrete< +pub fn ordering_satisfy_concrete< F: FnOnce() -> EquivalenceProperties, F2: FnOnce() -> OrderingEquivalenceProperties, >( diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 08e6360aa30e..ef1e2f284e8a 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -576,12 +576,14 @@ message AggregateExprNode { repeated LogicalExprNode expr = 2; bool distinct = 3; LogicalExprNode filter = 4; + repeated LogicalExprNode order_by = 5; } message AggregateUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; LogicalExprNode filter = 3; + repeated LogicalExprNode order_by = 4; } message ScalarUDFExprNode { @@ -1284,6 +1286,10 @@ message MaybeFilter { PhysicalExprNode expr = 1; } +message MaybePhysicalSortExprs { + repeated PhysicalSortExprNode sort_expr = 1; +} + message AggregateExecNode { repeated PhysicalExprNode group_expr = 1; repeated PhysicalExprNode aggr_expr = 2; @@ -1296,6 +1302,7 @@ message AggregateExecNode { repeated PhysicalExprNode null_expr = 8; repeated bool groups = 9; repeated MaybeFilter filter_expr = 10; + repeated MaybePhysicalSortExprs order_by_expr = 11; } message GlobalLimitExecNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 79f4f06d37d8..c217c4d7fea5 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -36,6 +36,9 @@ impl serde::Serialize for AggregateExecNode { if !self.filter_expr.is_empty() { len += 1; } + if !self.order_by_expr.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExecNode", len)?; if !self.group_expr.is_empty() { struct_ser.serialize_field("groupExpr", &self.group_expr)?; @@ -69,6 +72,9 @@ impl serde::Serialize for AggregateExecNode { if !self.filter_expr.is_empty() { struct_ser.serialize_field("filterExpr", &self.filter_expr)?; } + if 
!self.order_by_expr.is_empty() { + struct_ser.serialize_field("orderByExpr", &self.order_by_expr)?; + } struct_ser.end() } } @@ -96,6 +102,8 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "groups", "filter_expr", "filterExpr", + "order_by_expr", + "orderByExpr", ]; #[allow(clippy::enum_variant_names)] @@ -110,6 +118,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { NullExpr, Groups, FilterExpr, + OrderByExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -141,6 +150,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "nullExpr" | "null_expr" => Ok(GeneratedField::NullExpr), "groups" => Ok(GeneratedField::Groups), "filterExpr" | "filter_expr" => Ok(GeneratedField::FilterExpr), + "orderByExpr" | "order_by_expr" => Ok(GeneratedField::OrderByExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -170,6 +180,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { let mut null_expr__ = None; let mut groups__ = None; let mut filter_expr__ = None; + let mut order_by_expr__ = None; while let Some(k) = map.next_key()? 
{ match k { GeneratedField::GroupExpr => { @@ -232,6 +243,12 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { } filter_expr__ = Some(map.next_value()?); } + GeneratedField::OrderByExpr => { + if order_by_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("orderByExpr")); + } + order_by_expr__ = Some(map.next_value()?); + } } } Ok(AggregateExecNode { @@ -245,6 +262,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { null_expr: null_expr__.unwrap_or_default(), groups: groups__.unwrap_or_default(), filter_expr: filter_expr__.unwrap_or_default(), + order_by_expr: order_by_expr__.unwrap_or_default(), }) } } @@ -271,6 +289,9 @@ impl serde::Serialize for AggregateExprNode { if self.filter.is_some() { len += 1; } + if !self.order_by.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExprNode", len)?; if self.aggr_function != 0 { let v = AggregateFunction::from_i32(self.aggr_function) @@ -286,6 +307,9 @@ impl serde::Serialize for AggregateExprNode { if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; } + if !self.order_by.is_empty() { + struct_ser.serialize_field("orderBy", &self.order_by)?; + } struct_ser.end() } } @@ -301,6 +325,8 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { "expr", "distinct", "filter", + "order_by", + "orderBy", ]; #[allow(clippy::enum_variant_names)] @@ -309,6 +335,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { Expr, Distinct, Filter, + OrderBy, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -334,6 +361,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { "expr" => Ok(GeneratedField::Expr), "distinct" => Ok(GeneratedField::Distinct), "filter" => Ok(GeneratedField::Filter), + "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -357,6 +385,7 @@ impl<'de> 
serde::Deserialize<'de> for AggregateExprNode { let mut expr__ = None; let mut distinct__ = None; let mut filter__ = None; + let mut order_by__ = None; while let Some(k) = map.next_key()? { match k { GeneratedField::AggrFunction => { @@ -383,6 +412,12 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { } filter__ = map.next_value()?; } + GeneratedField::OrderBy => { + if order_by__.is_some() { + return Err(serde::de::Error::duplicate_field("orderBy")); + } + order_by__ = Some(map.next_value()?); + } } } Ok(AggregateExprNode { @@ -390,6 +425,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { expr: expr__.unwrap_or_default(), distinct: distinct__.unwrap_or_default(), filter: filter__, + order_by: order_by__.unwrap_or_default(), }) } } @@ -758,6 +794,9 @@ impl serde::Serialize for AggregateUdfExprNode { if self.filter.is_some() { len += 1; } + if !self.order_by.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateUDFExprNode", len)?; if !self.fun_name.is_empty() { struct_ser.serialize_field("funName", &self.fun_name)?; @@ -768,6 +807,9 @@ impl serde::Serialize for AggregateUdfExprNode { if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; } + if !self.order_by.is_empty() { + struct_ser.serialize_field("orderBy", &self.order_by)?; + } struct_ser.end() } } @@ -782,6 +824,8 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { "funName", "args", "filter", + "order_by", + "orderBy", ]; #[allow(clippy::enum_variant_names)] @@ -789,6 +833,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { FunName, Args, Filter, + OrderBy, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -813,6 +858,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { "funName" | "fun_name" => Ok(GeneratedField::FunName), "args" => Ok(GeneratedField::Args), "filter" => Ok(GeneratedField::Filter), + "orderBy" | 
"order_by" => Ok(GeneratedField::OrderBy), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -835,6 +881,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { let mut fun_name__ = None; let mut args__ = None; let mut filter__ = None; + let mut order_by__ = None; while let Some(k) = map.next_key()? { match k { GeneratedField::FunName => { @@ -855,12 +902,19 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { } filter__ = map.next_value()?; } + GeneratedField::OrderBy => { + if order_by__.is_some() { + return Err(serde::de::Error::duplicate_field("orderBy")); + } + order_by__ = Some(map.next_value()?); + } } } Ok(AggregateUdfExprNode { fun_name: fun_name__.unwrap_or_default(), args: args__.unwrap_or_default(), filter: filter__, + order_by: order_by__.unwrap_or_default(), }) } } @@ -11499,6 +11553,98 @@ impl<'de> serde::Deserialize<'de> for MaybeFilter { deserializer.deserialize_struct("datafusion.MaybeFilter", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for MaybePhysicalSortExprs { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.sort_expr.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.MaybePhysicalSortExprs", len)?; + if !self.sort_expr.is_empty() { + struct_ser.serialize_field("sortExpr", &self.sort_expr)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "sort_expr", + "sortExpr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + SortExpr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; 
+ + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "sortExpr" | "sort_expr" => Ok(GeneratedField::SortExpr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = MaybePhysicalSortExprs; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.MaybePhysicalSortExprs") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut sort_expr__ = None; + while let Some(k) = map.next_key()? 
{ + match k { + GeneratedField::SortExpr => { + if sort_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("sortExpr")); + } + sort_expr__ = Some(map.next_value()?); + } + } + } + Ok(MaybePhysicalSortExprs { + sort_expr: sort_expr__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.MaybePhysicalSortExprs", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for NegativeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 0736b750842a..83255797faca 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -692,6 +692,8 @@ pub struct AggregateExprNode { pub distinct: bool, #[prost(message, optional, boxed, tag = "4")] pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "5")] + pub order_by: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -702,6 +704,8 @@ pub struct AggregateUdfExprNode { pub args: ::prost::alloc::vec::Vec, #[prost(message, optional, boxed, tag = "3")] pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "4")] + pub order_by: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1820,6 +1824,12 @@ pub struct MaybeFilter { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct MaybePhysicalSortExprs { + #[prost(message, repeated, tag = "1")] + pub sort_expr: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct AggregateExecNode { #[prost(message, repeated, tag = "1")] pub group_expr: ::prost::alloc::vec::Vec, @@ -1842,6 +1852,8 @@ pub struct AggregateExecNode { 
pub groups: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "10")] pub filter_expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "11")] + pub order_by_expr: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index de4b03a069db..b40f867d98ef 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -990,6 +990,7 @@ pub fn parse_expr( .collect::, _>>()?, expr.distinct, parse_optional_expr(expr.filter.as_deref(), registry)?.map(Box::new), + parse_vec_expr(&expr.order_by, registry)?, ))) } ExprType::Alias(alias) => Ok(Expr::Alias( @@ -1390,6 +1391,7 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new), + parse_vec_expr(&pb.order_by, registry)?, ))) } @@ -1479,6 +1481,20 @@ pub fn from_proto_binary_op(op: &str) -> Result { } } +fn parse_vec_expr( + p: &[protobuf::LogicalExprNode], + registry: &dyn FunctionRegistry, +) -> Result>, Error> { + let res = p + .iter() + .map(|elem| { + parse_expr(elem, registry).map_err(|e| DataFusionError::Plan(e.to_string())) + }) + .collect::>>()?; + // Convert empty vector to None. 
+ Ok((!res.is_empty()).then_some(res)) +} + fn parse_optional_expr( p: Option<&protobuf::LogicalExprNode>, registry: &dyn FunctionRegistry, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 8d68a958b2dd..c61f90b2dc09 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -2525,6 +2525,7 @@ mod roundtrip_tests { vec![col("bananas")], false, None, + None, )); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2537,6 +2538,7 @@ mod roundtrip_tests { vec![col("bananas")], true, None, + None, )); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2549,6 +2551,7 @@ mod roundtrip_tests { vec![col("bananas"), lit(0.42_f32)], false, None, + None, )); let ctx = SessionContext::new(); @@ -2606,6 +2609,7 @@ mod roundtrip_tests { Arc::new(dummy_agg.clone()), vec![lit(1.0_f64)], Some(Box::new(lit(true))), + None, )); let ctx = SessionContext::new(); diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 0ffc893071ee..06156c9f40bb 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -622,6 +622,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref args, ref distinct, ref filter, + ref order_by, }) => { let aggr_function = match fun { AggregateFunction::ApproxDistinct => { @@ -679,6 +680,13 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Some(e) => Some(Box::new(e.as_ref().try_into()?)), None => None, }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, }; Self { expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), @@ -714,7 +722,12 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .collect::, Error>>()?, })), }, - Expr::AggregateUDF(expr::AggregateUDF { fun, args, filter }) => Self { + 
Expr::AggregateUDF(expr::AggregateUDF { + fun, + args, + filter, + order_by, + }) => Self { expr_type: Some(ExprType::AggregateUdfExpr(Box::new( protobuf::AggregateUdfExprNode { fun_name: fun.name.clone(), @@ -727,6 +740,13 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Some(e) => Some(Box::new(e.as_ref().try_into()?)), None => None, }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, }, ))), }, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 8e9ba1f6055c..1290db075259 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -61,6 +61,31 @@ impl From<&protobuf::PhysicalColumn> for Column { } } +/// Parses a physical sort expression from a protobuf. +/// +/// # Arguments +/// +/// * `proto` - Input proto with physical sort expression node +/// * `registry` - A registry knows how to build logical expressions out of user-defined function' names +/// * `input_schema` - The Arrow schema for the input, used for determining expression data types +/// when performing type coercion. +pub fn parse_physical_sort_expr( + proto: &protobuf::PhysicalSortExprNode, + registry: &dyn FunctionRegistry, + input_schema: &Schema, +) -> Result { + if let Some(expr) = &proto.expr { + let expr = parse_physical_expr(expr.as_ref(), registry, input_schema)?; + let options = SortOptions { + descending: !proto.asc, + nulls_first: proto.nulls_first, + }; + Ok(PhysicalSortExpr { expr, options }) + } else { + Err(proto_error("Unexpected empty physical expression")) + } +} + /// Parses a physical expression from a protobuf. 
/// /// # Arguments diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index cd8950e97bf8..94e26f2cf2b4 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -54,7 +54,7 @@ use prost::Message; use crate::common::proto_error; use crate::common::{csv_delimiter_to_string, str_to_byte}; use crate::physical_plan::from_proto::{ - parse_physical_expr, parse_protobuf_file_scan_config, + parse_physical_expr, parse_physical_sort_expr, parse_protobuf_file_scan_config, }; use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; use crate::protobuf::physical_expr_node::ExprType; @@ -409,13 +409,25 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter_expr .iter() .map(|expr| { - let x = expr - .expr + expr.expr .as_ref() - .map(|e| parse_physical_expr(e, registry, &physical_schema)); - x.transpose() + .map(|e| parse_physical_expr(e, registry, &physical_schema)) + .transpose() }) .collect::, _>>()?; + let physical_order_by_expr = hash_agg + .order_by_expr + .iter() + .map(|expr| { + expr.sort_expr + .iter() + .map(|e| { + parse_physical_sort_expr(e, registry, &physical_schema) + }) + .collect::>>() + .map(|exprs| (!exprs.is_empty()).then_some(exprs)) + }) + .collect::>>()?; let physical_aggr_expr: Vec> = hash_agg .aggr_expr @@ -473,6 +485,7 @@ impl AsExecutionPlan for PhysicalPlanNode { PhysicalGroupBy::new(group_expr, null_expr, groups), physical_aggr_expr, physical_filter_expr, + physical_order_by_expr, input, Arc::new((&input_schema).try_into()?), )?)) @@ -893,6 +906,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| expr.to_owned().try_into()) .collect::>>()?; + let order_by = exec + .order_by_expr() + .iter() + .map(|expr| expr.to_owned().try_into()) + .collect::>>()?; + let agg = exec .aggr_expr() .iter() @@ -942,6 +961,7 @@ impl AsExecutionPlan for PhysicalPlanNode { group_expr_name: group_names, aggr_expr: agg, filter_expr: filter, + order_by_expr: 
order_by, aggr_expr_name: agg_names, mode: agg_mode as i32, input: Some(Box::new(input)), @@ -1425,6 +1445,7 @@ mod roundtrip_tests { PhysicalGroupBy::new_single(groups.clone()), aggregates.clone(), vec![None], + vec![None], Arc::new(EmptyExec::new(false, schema.clone())), schema, )?)) @@ -1494,6 +1515,7 @@ mod roundtrip_tests { PhysicalGroupBy::new_single(groups.clone()), aggregates.clone(), vec![None], + vec![None], Arc::new(EmptyExec::new(false, schema.clone())), schema, )?), @@ -1707,6 +1729,7 @@ mod roundtrip_tests { PhysicalGroupBy::new_single(groups), aggregates.clone(), vec![None], + vec![None], Arc::new(EmptyExec::new(false, schema.clone())), schema, )?)) diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 90260b231fb7..f2f65b89f018 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -46,7 +46,7 @@ use crate::protobuf; use crate::protobuf::{physical_aggregate_expr_node, PhysicalSortExprNode, ScalarValue}; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::expressions::{DateTimeIntervalExpr, GetIndexedFieldExpr}; -use datafusion::physical_expr::ScalarFunctionExpr; +use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::joins::utils::JoinSide; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion_common::{DataFusionError, Result}; @@ -541,3 +541,31 @@ impl TryFrom>> for protobuf::MaybeFilter { } } } + +impl TryFrom>> for protobuf::MaybePhysicalSortExprs { + type Error = DataFusionError; + + fn try_from(sort_exprs: Option>) -> Result { + match sort_exprs { + None => Ok(protobuf::MaybePhysicalSortExprs { sort_expr: vec![] }), + Some(sort_exprs) => Ok(protobuf::MaybePhysicalSortExprs { + sort_expr: sort_exprs + .into_iter() + .map(|sort_expr| sort_expr.try_into()) + .collect::>>()?, + }), + } + } +} + +impl TryFrom for 
protobuf::PhysicalSortExprNode { + type Error = DataFusionError; + + fn try_from(sort_expr: PhysicalSortExpr) -> std::result::Result { + Ok(PhysicalSortExprNode { + expr: Some(Box::new(sort_expr.expr.try_into()?)), + asc: !sort_expr.options.descending, + nulls_first: sort_expr.options.nulls_first, + }) + } +} diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 482cd5f0c400..2bc3125d733d 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -113,7 +113,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let (fun, args) = self.aggregate_fn_to_expr(fun, function.args, schema, planner_context)?; return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, args, distinct, None, + fun, args, distinct, None, None, ))); }; @@ -128,7 +128,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) { let args = self.function_args_to_expr(function.args, schema, planner_context)?; - return Ok(Expr::AggregateUDF(expr::AggregateUDF::new(fm, args, None))); + return Ok(Expr::AggregateUDF(expr::AggregateUDF::new( + fm, args, None, None, + ))); } // Special case arrow_cast (as its type is dependent on its argument value) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index a120b3400ad8..b914149cf4e1 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -321,11 +321,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { within_group, } = array_agg; - if let Some(order_by) = order_by { - return Err(DataFusionError::NotImplemented(format!( - "ORDER BY not supported in ARRAY_AGG: {order_by}" - ))); - } + let order_by = if let Some(order_by) = order_by { + // TODO: Once sqlparser supports multiple order by clause, handle it + // see issue: https://github.com/sqlparser-rs/sqlparser-rs/issues/875 + Some(vec![self.order_by_to_sort_expr( + *order_by, + input_schema, + planner_context, + 
)?]) + } else { + None + }; if let Some(limit) = limit { return Err(DataFusionError::NotImplemented(format!( @@ -341,11 +347,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let args = vec![self.sql_expr_to_logical_expr(*expr, input_schema, planner_context)?]; + // next, aggregate built-ins let fun = AggregateFunction::ArrayAgg; - Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, args, distinct, None, + fun, args, distinct, None, order_by, ))) } @@ -479,6 +485,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun, args, distinct, + order_by, .. }) => Ok(Expr::AggregateFunction(expr::AggregateFunction::new( fun, @@ -489,6 +496,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, )?)), + order_by, ))), _ => Err(DataFusionError::Internal( "AggregateExpressionWithFilter expression was not an AggregateFunction" diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 160543360963..dd96f1b2bbea 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -169,6 +169,7 @@ where args, distinct, filter, + order_by, }) => Ok(Expr::AggregateFunction(AggregateFunction::new( fun.clone(), args.iter() @@ -176,6 +177,7 @@ where .collect::>>()?, *distinct, filter.clone(), + order_by.clone(), ))), Expr::WindowFunction(WindowFunction { fun, @@ -198,15 +200,19 @@ where .collect::>>()?, window_frame.clone(), ))), - Expr::AggregateUDF(AggregateUDF { fun, args, filter }) => { - Ok(Expr::AggregateUDF(AggregateUDF::new( - fun.clone(), - args.iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - filter.clone(), - ))) - } + Expr::AggregateUDF(AggregateUDF { + fun, + args, + filter, + order_by, + }) => Ok(Expr::AggregateUDF(AggregateUDF::new( + fun.clone(), + args.iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, + filter.clone(), + order_by.clone(), + ))), Expr::Alias(nested_expr, alias_name) => Ok(Expr::Alias( Box::new(clone_with_replacement(nested_expr, 
replacement_fn)?), alias_name.clone(), diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 2b8ffde4229c..24bac58dd523 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -278,6 +278,8 @@ pub async fn from_substrait_rel( input.schema(), extensions, filter, + // TODO: Add parsing of order_by also + None, distinct, ) .await @@ -549,6 +551,7 @@ pub async fn from_substrait_agg_func( input_schema: &DFSchema, extensions: &HashMap, filter: Option>, + order_by: Option>, distinct: bool, ) -> Result> { let mut args: Vec = vec![]; @@ -579,6 +582,7 @@ pub async fn from_substrait_agg_func( args, distinct, filter, + order_by, }))) } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 17d424cea6a0..0523c51cc312 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -415,7 +415,8 @@ pub fn to_substrait_agg_measure( ), ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter }) => { + // TODO: Once substrait supports order by, add handling for it. + Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by: _order_by }) => { let mut arguments: Vec = vec![]; for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, extension_info)?)) }); diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 68c02ef55019..d02c733efc3a 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -194,10 +194,10 @@ sum(expression) ### `array_agg` -Returns an array created from the expression elements. +Returns an array created from the expression elements. 
If an ordering requirement is given, elements are inserted in the order of the required ordering. ``` -array_agg(expression) +array_agg(expression [ORDER BY expression]) ``` #### Arguments diff --git a/docs/source/user-guide/sql/select.md b/docs/source/user-guide/sql/select.md index 68be88d7cff3..cb73e0852030 100644 --- a/docs/source/user-guide/sql/select.md +++ b/docs/source/user-guide/sql/select.md @@ -189,6 +189,15 @@ Example: SELECT a, b, MAX(c) FROM table GROUP BY a, b ``` +Some aggregation functions, such as `ARRAY_AGG`, accept an optional ordering requirement. If a requirement is given, +the aggregation is calculated in the order of the requirement. + +Example: + +```sql +SELECT a, b, ARRAY_AGG(c ORDER BY d) FROM table GROUP BY a, b +``` + ## HAVING clause Example: