diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index fc722e9ac352..39e42bbd6049 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -55,6 +55,7 @@ mod utils; pub use datafusion_expr::AggregateFunction; pub use datafusion_physical_expr::expressions::create_aggregate_expr; +use datafusion_physical_expr::expressions::{ArrayAgg, FirstValue, LastValue}; /// Hash aggregate modes #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -388,6 +389,16 @@ fn get_finest_requirement< Ok(result) } +/// Checks whether the given aggregate expression is order-sensitive. +/// For instance, a `SUM` aggregation doesn't depend on the order of its inputs. +/// However, a `FirstAgg` depends on the input ordering (if the order changes, +/// the first value in the list would change). +fn is_order_sensitive(aggr_expr: &Arc) -> bool { + aggr_expr.as_any().is::() + || aggr_expr.as_any().is::() + || aggr_expr.as_any().is::() +} + impl AggregateExec { /// Create a new hash aggregate execution plan pub fn try_new( @@ -395,7 +406,7 @@ impl AggregateExec { group_by: PhysicalGroupBy, aggr_expr: Vec>, filter_expr: Vec>>, - order_by_expr: Vec>>, + mut order_by_expr: Vec>>, input: Arc, input_schema: SchemaRef, ) -> Result { @@ -413,6 +424,18 @@ impl AggregateExec { // In other modes, all groups are collapsed, therefore their input schema // can not contain expressions in the requirement. 
if mode == AggregateMode::Partial || mode == AggregateMode::Single { + order_by_expr = aggr_expr + .iter() + .zip(order_by_expr.into_iter()) + .map(|(aggr_expr, fn_reqs)| { + // If aggregation function is ordering sensitive, keep ordering requirement as is; otherwise ignore requirement + if is_order_sensitive(aggr_expr) { + fn_reqs + } else { + None + } + }) + .collect::>(); let requirement = get_finest_requirement( &order_by_expr, || input.equivalence_properties(), diff --git a/datafusion/core/tests/dataframe_functions.rs b/datafusion/core/tests/dataframe_functions.rs index 2f4e4d9d8c98..e1173e1d5c07 100644 --- a/datafusion/core/tests/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe_functions.rs @@ -155,11 +155,11 @@ async fn test_fn_approx_median() -> Result<()> { let expr = approx_median(col("b")); let expected = vec![ - "+----------------------+", - "| APPROXMEDIAN(test.b) |", - "+----------------------+", - "| 10 |", - "+----------------------+", + "+-----------------------+", + "| APPROX_MEDIAN(test.b) |", + "+-----------------------+", + "| 10 |", + "+-----------------------+", ]; let df = create_test_table().await?; @@ -175,11 +175,11 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { let expr = approx_percentile_cont(col("b"), lit(0.5)); let expected = vec![ - "+-------------------------------------------+", - "| APPROXPERCENTILECONT(test.b,Float64(0.5)) |", - "+-------------------------------------------+", - "| 10 |", - "+-------------------------------------------+", + "+---------------------------------------------+", + "| APPROX_PERCENTILE_CONT(test.b,Float64(0.5)) |", + "+---------------------------------------------+", + "| 10 |", + "+---------------------------------------------+", ]; let df = create_test_table().await?; diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index e847ea0c0ebf..3ff81581c096 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ 
b/datafusion/core/tests/sql/aggregates.rs @@ -29,7 +29,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { // The results for this query should be something like the following: // +------------------------------------------+ - // | ARRAYAGG(DISTINCT aggregate_test_100.c2) | + // | ARRAY_AGG(DISTINCT aggregate_test_100.c2) | // +------------------------------------------+ // | [4, 2, 3, 5, 1] | // +------------------------------------------+ @@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { assert_eq!( *actual[0].schema(), Schema::new(vec![Field::new_list( - "ARRAYAGG(DISTINCT aggregate_test_100.c2)", + "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", Field::new("item", DataType::UInt32, true), false ),]) diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt index 17d89a9f0562..ab3516e9e55b 100644 --- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt +++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt @@ -41,7 +41,7 @@ LOCATION '../../testing/data/csv/aggregate_test_100.csv' ####### # https://github.com/apache/arrow-datafusion/issues/3353 -statement error DataFusion error: Schema error: Schema contains duplicate unqualified field name "APPROXDISTINCT\(aggregate_test_100\.c9\)" +statement error DataFusion error: Schema error: Schema contains duplicate unqualified field name "APPROX_DISTINCT\(aggregate_test_100\.c9\)" SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_weight diff --git a/datafusion/core/tests/sqllogictests/test_files/groupby.slt b/datafusion/core/tests/sqllogictests/test_files/groupby.slt index b81731e64c98..8bac60bbba8a 100644 --- a/datafusion/core/tests/sqllogictests/test_files/groupby.slt +++ b/datafusion/core/tests/sqllogictests/test_files/groupby.slt @@ -1974,25 +1974,26 @@ query III # 
test_source_sorted_groupby2 - +# If ordering is not important for the aggregation function, we should ignore the ordering requirement. Hence +# "ORDER BY a DESC" should have no effect. query TT EXPLAIN SELECT a, d, - SUM(c) as summation1 + SUM(c ORDER BY a DESC) as summation1 FROM annotated_data_infinite2 GROUP BY d, a ---- logical_plan -Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotated_data_infinite2.c) AS summation1 ---Aggregate: groupBy=[[annotated_data_infinite2.d, annotated_data_infinite2.a]], aggr=[[SUM(annotated_data_infinite2.c)]] +Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS summation1 +--Aggregate: groupBy=[[annotated_data_infinite2.d, annotated_data_infinite2.a]], aggr=[[SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] ----TableScan: annotated_data_infinite2 projection=[a, c, d] physical_plan -ProjectionExec: expr=[a@1 as a, d@0 as d, SUM(annotated_data_infinite2.c)@2 as summation1] +ProjectionExec: expr=[a@1 as a, d@0 as d, SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as summation1] --AggregateExec: mode=Single, gby=[d@2 as d, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=PartiallyOrdered ----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST], has_header=true query III SELECT a, d, - SUM(c) as summation1 + SUM(c ORDER BY a DESC) as summation1 FROM annotated_data_infinite2 GROUP BY d, a ---- @@ -2007,6 +2008,85 @@ SELECT a, d, 1 4 913 1 2 848 +# test_source_sorted_groupby3 + +query TT +EXPLAIN SELECT a, b, FIRST_VALUE(c ORDER BY a DESC) as first_c + FROM annotated_data_infinite2 + GROUP BY a, b +---- +logical_plan +Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, 
FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS first_c +--Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] +----TableScan: annotated_data_infinite2 projection=[a, b, c] +physical_plan +ProjectionExec: expr=[a@0 as a, b@1 as b, FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as first_c] +--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[FIRST_VALUE(annotated_data_infinite2.c)], ordering_mode=FullyOrdered +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true + +query III +SELECT a, b, FIRST_VALUE(c ORDER BY a DESC) as first_c + FROM annotated_data_infinite2 + GROUP BY a, b +---- +0 0 0 +0 1 25 +1 2 50 +1 3 75 + +# test_source_sorted_groupby4 + +query TT +EXPLAIN SELECT a, b, LAST_VALUE(c ORDER BY a DESC) as last_c + FROM annotated_data_infinite2 + GROUP BY a, b +---- +logical_plan +Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS last_c +--Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] +----TableScan: annotated_data_infinite2 projection=[a, b, c] +physical_plan +ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as last_c] +--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=FullyOrdered +----CsvExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true + +query III +SELECT a, b, LAST_VALUE(c ORDER BY a DESC) as last_c + FROM annotated_data_infinite2 + GROUP BY a, b +---- +0 0 24 +0 1 49 +1 2 74 +1 3 99 + +# When LAST_VALUE or FIRST_VALUE do not contain an ordering requirement, +# queries should still work. However, the result depends on the scanning order +# and is not deterministic. +query TT +EXPLAIN SELECT a, b, LAST_VALUE(c) as last_c + FROM annotated_data_infinite2 + GROUP BY a, b +---- +logical_plan +Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(annotated_data_infinite2.c) AS last_c +--Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[LAST_VALUE(annotated_data_infinite2.c)]] +----TableScan: annotated_data_infinite2 projection=[a, b, c] +physical_plan +ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c)@2 as last_c] +--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=FullyOrdered +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true + +query III +SELECT a, b, LAST_VALUE(c) as last_c + FROM annotated_data_infinite2 + GROUP BY a, b +---- +0 0 24 +0 1 49 +1 2 74 +1 3 99 + statement ok drop table annotated_data_infinite2; @@ -2038,12 +2118,12 @@ EXPLAIN SELECT country, (ARRAY_AGG(amount ORDER BY amount ASC)) AS amounts GROUP BY country ---- logical_plan -Projection: sales_global.country, ARRAYAGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts ---Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAYAGG(sales_global.amount) ORDER BY 
[sales_global.amount ASC NULLS LAST]]] +Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts +--Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] ----TableScan: sales_global projection=[country, amount] physical_plan -ProjectionExec: expr=[country@0 as country, ARRAYAGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAYAGG(sales_global.amount)] +ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount)] ----SortExec: expr=[amount@1 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2067,13 +2147,13 @@ EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, GROUP BY s.country ---- logical_plan -Projection: s.country, ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 ---Aggregate: groupBy=[[s.country]], aggr=[[ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)]] +Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)]] ----SubqueryAlias: s ------TableScan: sales_global projection=[country, amount] physical_plan -ProjectionExec: expr=[country@0 as country, ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAYAGG(s.amount), SUM(s.amount)] +ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 
as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)] ----SortExec: expr=[amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2120,14 +2200,14 @@ EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, GROUP BY s.country ---- logical_plan -Projection: s.country, ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 ---Aggregate: groupBy=[[s.country]], aggr=[[ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)]] +Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)]] ----SubqueryAlias: s ------Sort: sales_global.country ASC NULLS LAST --------TableScan: sales_global projection=[country, amount] physical_plan -ProjectionExec: expr=[country@0 as country, ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAYAGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered +ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered ----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2156,14 +2236,14 @@ EXPLAIN SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) GROUP BY s.country, s.zip_code ---- logical_plan -Projection: s.country, s.zip_code, ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 ---Aggregate: groupBy=[[s.country, s.zip_code]], aggr=[[ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], 
SUM(s.amount)]] +Projection: s.country, s.zip_code, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country, s.zip_code]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)]] ----SubqueryAlias: s ------Sort: sales_global.country ASC NULLS LAST --------TableScan: sales_global projection=[zip_code, country, amount] physical_plan -ProjectionExec: expr=[country@0 as country, zip_code@1 as zip_code, ARRAYAGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@2 as amounts, SUM(s.amount)@3 as sum1] ---AggregateExec: mode=Single, gby=[country@1 as country, zip_code@0 as zip_code], aggr=[ARRAYAGG(s.amount), SUM(s.amount)], ordering_mode=PartiallyOrdered +ProjectionExec: expr=[country@0 as country, zip_code@1 as zip_code, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@2 as amounts, SUM(s.amount)@3 as sum1] +--AggregateExec: mode=Single, gby=[country@1 as country, zip_code@0 as zip_code], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=PartiallyOrdered ----SortExec: expr=[country@1 ASC NULLS LAST,amount@2 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2192,14 +2272,14 @@ EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC) AS amounts GROUP BY s.country ---- logical_plan -Projection: s.country, ARRAYAGG(s.amount) ORDER BY [s.country DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 ---Aggregate: groupBy=[[s.country]], aggr=[[ARRAYAGG(s.amount) ORDER BY [s.country DESC NULLS FIRST], SUM(s.amount)]] +Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST], SUM(s.amount)]] ----SubqueryAlias: s ------Sort: sales_global.country ASC NULLS LAST --------TableScan: sales_global projection=[country, amount] physical_plan -ProjectionExec: expr=[country@0 as country, ARRAYAGG(s.amount) 
ORDER BY [s.country DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAYAGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered +ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered ----SortExec: expr=[country@0 ASC NULLS LAST] ------MemoryExec: partitions=1, partition_sizes=[1] @@ -2227,14 +2307,14 @@ EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC, s.amount D GROUP BY s.country ---- logical_plan -Projection: s.country, ARRAYAGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 ---Aggregate: groupBy=[[s.country]], aggr=[[ARRAYAGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST], SUM(s.amount)]] +Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 +--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST], SUM(s.amount)]] ----SubqueryAlias: s ------Sort: sales_global.country ASC NULLS LAST --------TableScan: sales_global projection=[country, amount] physical_plan -ProjectionExec: expr=[country@0 as country, ARRAYAGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] ---AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAYAGG(s.amount), SUM(s.amount)], ordering_mode=FullyOrdered +ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] +--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount), 
SUM(s.amount)], ordering_mode=FullyOrdered ----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] ------MemoryExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/core/tests/sqllogictests/test_files/window.slt b/datafusion/core/tests/sqllogictests/test_files/window.slt index 32f45dbb57b4..8ab9b29da4dc 100644 --- a/datafusion/core/tests/sqllogictests/test_files/window.slt +++ b/datafusion/core/tests/sqllogictests/test_files/window.slt @@ -2007,16 +2007,16 @@ query TT EXPLAIN SELECT ARRAY_AGG(c13) as array_agg1 FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) ---- logical_plan -Projection: ARRAYAGG(aggregate_test_100.c13) AS array_agg1 ---Aggregate: groupBy=[[]], aggr=[[ARRAYAGG(aggregate_test_100.c13)]] +Projection: ARRAY_AGG(aggregate_test_100.c13) AS array_agg1 +--Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(aggregate_test_100.c13)]] ----Limit: skip=0, fetch=1 ------Sort: aggregate_test_100.c13 ASC NULLS LAST, fetch=1 --------TableScan: aggregate_test_100 projection=[c13] physical_plan -ProjectionExec: expr=[ARRAYAGG(aggregate_test_100.c13)@0 as array_agg1] ---AggregateExec: mode=Final, gby=[], aggr=[ARRAYAGG(aggregate_test_100.c13)] +ProjectionExec: expr=[ARRAY_AGG(aggregate_test_100.c13)@0 as array_agg1] +--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)] ----CoalescePartitionsExec -------AggregateExec: mode=Partial, gby=[], aggr=[ARRAYAGG(aggregate_test_100.c13)] +------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)] --------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 ----------GlobalLimitExec: skip=0, fetch=1 ------------SortExec: fetch=1, expr=[c13@0 ASC NULLS LAST] @@ -3017,6 +3017,19 @@ SELECT a, b, c, 0 0 3 11 96 11 2 10 36 10 36 11 5 11 9 0 0 4 9 72 9 NULL 14 45 14 45 9 4 9 9 +#fn aggregate order by with window frame +# In window expressions, aggregate functions should not have an ordering requirement, such requirements +# should be defined in the 
window frame. Therefore, the query below should generate an error. Note that +# PostgreSQL also behaves this way. +statement error DataFusion error: Error during planning: Aggregate ORDER BY is not implemented for window functions +SELECT SUM(b ORDER BY a ASC) OVER() as sum1 + FROM annotated_data_infinite2 + +# Even if the ordering requirements of the window clause and the aggregate function match, +# we should raise an error when an ordering requirement is given to aggregate functions in window clauses. +statement error DataFusion error: Error during planning: Aggregate ORDER BY is not implemented for window functions +EXPLAIN SELECT a, b, LAST_VALUE(c ORDER BY a ASC) OVER (order by a ASC) as last_c + FROM annotated_data_infinite2 statement ok drop table annotated_data_finite2 @@ -3135,4 +3148,4 @@ SELECT WINDOW window1 AS (ORDER BY C12), window1 AS (ORDER BY C3) ORDER BY C3 - LIMIT 5 \ No newline at end of file + LIMIT 5 diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 7d5fa277de7b..8258c8b80585 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -42,6 +42,10 @@ pub enum AggregateFunction { ApproxDistinct, /// array_agg ArrayAgg, + /// first_value + FirstValue, + /// last_value + LastValue, /// Variance (Sample) Variance, /// Variance (Population) @@ -76,10 +80,43 @@ pub enum AggregateFunction { BoolOr, } +impl AggregateFunction { + fn name(&self) -> &str { + use AggregateFunction::*; + match self { + Count => "COUNT", + Sum => "SUM", + Min => "MIN", + Max => "MAX", + Avg => "AVG", + Median => "MEDIAN", + ApproxDistinct => "APPROX_DISTINCT", + ArrayAgg => "ARRAY_AGG", + FirstValue => "FIRST_VALUE", + LastValue => "LAST_VALUE", + Variance => "VARIANCE", + VariancePop => "VARIANCE_POP", + Stddev => "STDDEV", + StddevPop => "STDDEV_POP", + Covariance => "COVARIANCE", + CovariancePop => "COVARIANCE_POP", + Correlation => "CORRELATION", + ApproxPercentileCont => 
"APPROX_PERCENTILE_CONT", + ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", + ApproxMedian => "APPROX_MEDIAN", + Grouping => "GROUPING", + BitAnd => "BIT_AND", + BitOr => "BIT_OR", + BitXor => "BIT_XOR", + BoolAnd => "BOOL_AND", + BoolOr => "BOOL_OR", + } + } +} + impl fmt::Display for AggregateFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - // uppercase of the debug. - write!(f, "{}", format!("{self:?}").to_uppercase()) + write!(f, "{}", self.name()) } } @@ -101,6 +138,8 @@ impl FromStr for AggregateFunction { "min" => AggregateFunction::Min, "sum" => AggregateFunction::Sum, "array_agg" => AggregateFunction::ArrayAgg, + "first_value" => AggregateFunction::FirstValue, + "last_value" => AggregateFunction::LastValue, // statistical "corr" => AggregateFunction::Correlation, "covar" => AggregateFunction::Covariance, @@ -182,6 +221,9 @@ pub fn return_type( Ok(coerced_data_types[0].clone()) } AggregateFunction::Grouping => Ok(DataType::Int32), + AggregateFunction::FirstValue | AggregateFunction::LastValue => { + Ok(coerced_data_types[0].clone()) + } } } @@ -232,7 +274,9 @@ pub fn signature(fun: &AggregateFunction) -> Signature { | AggregateFunction::Stddev | AggregateFunction::StddevPop | AggregateFunction::Median - | AggregateFunction::ApproxMedian => { + | AggregateFunction::ApproxMedian + | AggregateFunction::FirstValue + | AggregateFunction::LastValue => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::Covariance | AggregateFunction::CovariancePop => { diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 2cc6c322e6c2..4f02bf3dfd2a 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -264,7 +264,9 @@ pub fn coerce_types( } Ok(input_types.to_vec()) } - AggregateFunction::Median => Ok(input_types.to_vec()), + AggregateFunction::Median + | 
AggregateFunction::FirstValue + | AggregateFunction::LastValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), } } diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 2d91dca8cc06..1bae3a162e50 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -42,10 +42,15 @@ pub enum WindowFunction { /// Find DataFusion's built-in window function by name. pub fn find_df_window_func(name: &str) -> Option { let name = name.to_lowercase(); - if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { - Some(WindowFunction::AggregateFunction(aggregate)) - } else if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) { + // Code paths for window functions leveraging ordinary aggregators and + // built-in window functions are quite different, and the same function + // may have different implementations for these cases. If the sought + // function is not found among built-in window functions, we search for + // it among aggregate functions. 
+ if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) { Some(WindowFunction::BuiltInWindowFunction(built_in_function)) + } else if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { + Some(WindowFunction::AggregateFunction(aggregate)) } else { None } @@ -53,19 +58,7 @@ pub fn find_df_window_func(name: &str) -> Option { impl fmt::Display for BuiltInWindowFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - BuiltInWindowFunction::RowNumber => write!(f, "ROW_NUMBER"), - BuiltInWindowFunction::Rank => write!(f, "RANK"), - BuiltInWindowFunction::DenseRank => write!(f, "DENSE_RANK"), - BuiltInWindowFunction::PercentRank => write!(f, "PERCENT_RANK"), - BuiltInWindowFunction::CumeDist => write!(f, "CUME_DIST"), - BuiltInWindowFunction::Ntile => write!(f, "NTILE"), - BuiltInWindowFunction::Lag => write!(f, "LAG"), - BuiltInWindowFunction::Lead => write!(f, "LEAD"), - BuiltInWindowFunction::FirstValue => write!(f, "FIRST_VALUE"), - BuiltInWindowFunction::LastValue => write!(f, "LAST_VALUE"), - BuiltInWindowFunction::NthValue => write!(f, "NTH_VALUE"), - } + write!(f, "{}", self.name()) } } @@ -112,6 +105,25 @@ pub enum BuiltInWindowFunction { NthValue, } +impl BuiltInWindowFunction { + fn name(&self) -> &str { + use BuiltInWindowFunction::*; + match self { + RowNumber => "ROW_NUMBER", + Rank => "RANK", + DenseRank => "DENSE_RANK", + PercentRank => "PERCENT_RANK", + CumeDist => "CUME_DIST", + Ntile => "NTILE", + Lag => "LAG", + Lead => "LEAD", + FirstValue => "FIRST_VALUE", + LastValue => "LAST_VALUE", + NthValue => "NTH_VALUE", + } + } +} + impl FromStr for BuiltInWindowFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 2410f0147ef5..69ff89a3929d 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ 
b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -308,6 +308,16 @@ pub fn create_aggregate_expr( "MEDIAN(DISTINCT) aggregations are not available".to_string(), )); } + (AggregateFunction::FirstValue, _) => Arc::new(expressions::FirstValue::new( + input_phy_exprs[0].clone(), + name, + input_phy_types[0].clone(), + )), + (AggregateFunction::LastValue, _) => Arc::new(expressions::LastValue::new( + input_phy_exprs[0].clone(), + name, + input_phy_types[0].clone(), + )), }) } diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs new file mode 100644 index 000000000000..f65360c75199 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -0,0 +1,272 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the FIRST_VALUE/LAST_VALUE aggregations. 
+ +use crate::aggregate::utils::down_cast_any_ref; +use crate::expressions::format_state_name; +use crate::{AggregateExpr, PhysicalExpr}; + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field}; +use arrow_array::Array; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::Accumulator; + +use std::any::Any; +use std::sync::Arc; + +/// FIRST_VALUE aggregate expression +#[derive(Debug)] +pub struct FirstValue { + name: String, + pub data_type: DataType, + expr: Arc, +} + +impl FirstValue { + /// Creates a new FIRST_VALUE aggregation function. + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + Self { + name: name.into(), + data_type, + expr, + } + } +} + +impl AggregateExpr for FirstValue { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(FirstValueAccumulator::try_new(&self.data_type)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![Field::new( + format_state_name(&self.name, "first_value"), + self.data_type.clone(), + true, + )]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } + + fn reverse_expr(&self) -> Option> { + Some(Arc::new(LastValue::new( + self.expr.clone(), + self.name.clone(), + self.data_type.clone(), + ))) + } + + fn create_sliding_accumulator(&self) -> Result> { + Ok(Box::new(FirstValueAccumulator::try_new(&self.data_type)?)) + } +} + +impl PartialEq for FirstValue { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.data_type == x.data_type + && self.expr.eq(&x.expr) + }) + .unwrap_or(false) + } +} + +#[derive(Debug)] +struct FirstValueAccumulator { + first: ScalarValue, +} + +impl FirstValueAccumulator { + /// Creates 
a new `FirstValueAccumulator` for the given `data_type`. + pub fn try_new(data_type: &DataType) -> Result { + ScalarValue::try_from(data_type).map(|value| Self { first: value }) + } +} + +impl Accumulator for FirstValueAccumulator { + fn state(&self) -> Result> { + Ok(vec![self.first.clone()]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // If we have seen first value, we shouldn't update it + let values = &values[0]; + if !values.is_empty() { + self.first = ScalarValue::try_from_array(values, 0)?; + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // FIRST_VALUE(first1, first2, first3, ...) + self.update_batch(states) + } + + fn evaluate(&self) -> Result { + Ok(self.first.clone()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.first) + + self.first.size() + } +} + +/// LAST_VALUE aggregate expression +#[derive(Debug)] +pub struct LastValue { + name: String, + pub data_type: DataType, + expr: Arc, +} + +impl LastValue { + /// Creates a new LAST_VALUE aggregation function. 
+ pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + Self { + name: name.into(), + data_type, + expr, + } + } +} + +impl AggregateExpr for LastValue { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(LastValueAccumulator::try_new(&self.data_type)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![Field::new( + format_state_name(&self.name, "last_value"), + self.data_type.clone(), + true, + )]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } + + fn reverse_expr(&self) -> Option> { + Some(Arc::new(FirstValue::new( + self.expr.clone(), + self.name.clone(), + self.data_type.clone(), + ))) + } + + fn create_sliding_accumulator(&self) -> Result> { + Ok(Box::new(LastValueAccumulator::try_new(&self.data_type)?)) + } +} + +impl PartialEq for LastValue { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.data_type == x.data_type + && self.expr.eq(&x.expr) + }) + .unwrap_or(false) + } +} + +#[derive(Debug)] +struct LastValueAccumulator { + last: ScalarValue, +} + +impl LastValueAccumulator { + /// Creates a new `LastValueAccumulator` for the given `data_type`. + pub fn try_new(data_type: &DataType) -> Result { + Ok(Self { + last: ScalarValue::try_from(data_type)?, + }) + } +} + +impl Accumulator for LastValueAccumulator { + fn state(&self) -> Result> { + Ok(vec![self.last.clone()]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + if !values.is_empty() { + // Update with last value in the array. 
+ self.last = ScalarValue::try_from_array(values, values.len() - 1)?; + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // LAST_VALUE(last1, last2, last3, ...) + self.update_batch(states) + } + + fn evaluate(&self) -> Result { + Ok(self.last.clone()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.last) + self.last.size() + } +} diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 34302c5aaf51..8da635cfb2ea 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -37,6 +37,7 @@ pub(crate) mod correlation; pub(crate) mod count; pub(crate) mod count_distinct; pub(crate) mod covariance; +pub(crate) mod first_last; pub(crate) mod grouping; pub(crate) mod median; #[macro_use] diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index e65d17fa284a..66d593c5cafa 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -54,6 +54,7 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; +pub use crate::aggregate::first_last::{FirstValue, LastValue}; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 446b4a027591..7c35452085e3 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -570,6 +570,10 @@ enum AggregateFunction { BIT_XOR = 21; BOOL_AND = 22; BOOL_OR = 23; + // When a function with the same name exists among built-in window functions, + // we append "_AGG" to obey 
name scoping rules. + FIRST_VALUE_AGG = 24; + LAST_VALUE_AGG = 25; } message AggregateExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 571fd45f8a6a..6dbe25c2d822 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -463,6 +463,8 @@ impl serde::Serialize for AggregateFunction { Self::BitXor => "BIT_XOR", Self::BoolAnd => "BOOL_AND", Self::BoolOr => "BOOL_OR", + Self::FirstValueAgg => "FIRST_VALUE_AGG", + Self::LastValueAgg => "LAST_VALUE_AGG", }; serializer.serialize_str(variant) } @@ -498,6 +500,8 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BIT_XOR", "BOOL_AND", "BOOL_OR", + "FIRST_VALUE_AGG", + "LAST_VALUE_AGG", ]; struct GeneratedVisitor; @@ -564,6 +568,8 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BIT_XOR" => Ok(AggregateFunction::BitXor), "BOOL_AND" => Ok(AggregateFunction::BoolAnd), "BOOL_OR" => Ok(AggregateFunction::BoolOr), + "FIRST_VALUE_AGG" => Ok(AggregateFunction::FirstValueAgg), + "LAST_VALUE_AGG" => Ok(AggregateFunction::LastValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index a5c0603239f3..7e48db10f3ae 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2400,6 +2400,10 @@ pub enum AggregateFunction { BitXor = 21, BoolAnd = 22, BoolOr = 23, + /// When a function with the same name exists among built-in window functions, + /// we append "_AGG" to obey name scoping rules. + FirstValueAgg = 24, + LastValueAgg = 25, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. 
@@ -2434,6 +2438,8 @@ impl AggregateFunction { AggregateFunction::BitXor => "BIT_XOR", AggregateFunction::BoolAnd => "BOOL_AND", AggregateFunction::BoolOr => "BOOL_OR", + AggregateFunction::FirstValueAgg => "FIRST_VALUE_AGG", + AggregateFunction::LastValueAgg => "LAST_VALUE_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2465,6 +2471,8 @@ impl AggregateFunction { "BIT_XOR" => Some(Self::BitXor), "BOOL_AND" => Some(Self::BoolAnd), "BOOL_OR" => Some(Self::BoolOr), + "FIRST_VALUE_AGG" => Some(Self::FirstValueAgg), + "LAST_VALUE_AGG" => Some(Self::LastValueAgg), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b40f867d98ef..1150220bef4a 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -529,6 +529,8 @@ impl From for AggregateFunction { protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian, protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::Median => Self::Median, + protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue, + protobuf::AggregateFunction::LastValueAgg => Self::LastValue, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 06156c9f40bb..191c49194407 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -388,6 +388,8 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::ApproxMedian => Self::ApproxMedian, AggregateFunction::Grouping => Self::Grouping, AggregateFunction::Median => Self::Median, + AggregateFunction::FirstValue => Self::FirstValueAgg, + AggregateFunction::LastValue => Self::LastValueAgg, } } } @@ -667,6 +669,12 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, 
AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::FirstValue => { + protobuf::AggregateFunction::FirstValueAgg + } + AggregateFunction::LastValue => { + protobuf::AggregateFunction::LastValueAgg + } }; let aggregate_expr = protobuf::AggregateExprNode { diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 0c5b460ead95..70489203b200 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -53,6 +53,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))); }; + // If function is a window function (it has an OVER clause), + // it shouldn't have ordering requirement as function argument + // required ordering should be defined in OVER clause. + if !function.order_by.is_empty() && function.over.is_some() { + return Err(DataFusionError::Plan( + "Aggregate ORDER BY is not implemented for window functions".to_string(), + )); + } + // then, window function if let Some(WindowType::WindowSpec(window)) = function.over.take() { let partition_by = window @@ -107,10 +116,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // next, aggregate built-ins if let Ok(fun) = AggregateFunction::from_str(&name) { let distinct = function.distinct; + let order_by = + self.order_by_to_sort_expr(&function.order_by, schema, planner_context)?; + let order_by = (!order_by.is_empty()).then_some(order_by); let (fun, args) = self.aggregate_fn_to_expr(fun, function.args, schema, planner_context)?; return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, args, distinct, None, None, + fun, args, distinct, None, order_by, ))); }; diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index 452761454afc..d4585c148bbf 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -1495,8 +1495,8 @@ fn select_count_column() { #[test] fn 
select_approx_median() { let sql = "SELECT approx_median(age) FROM person"; - let expected = "Projection: APPROXMEDIAN(person.age)\ - \n Aggregate: groupBy=[[]], aggr=[[APPROXMEDIAN(person.age)]]\ + let expected = "Projection: APPROX_MEDIAN(person.age)\ + \n Aggregate: groupBy=[[]], aggr=[[APPROX_MEDIAN(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -2427,8 +2427,8 @@ fn approx_median_window() { let sql = "SELECT order_id, APPROX_MEDIAN(qty) OVER(PARTITION BY order_id) from orders"; let expected = "\ - Projection: orders.order_id, APPROXMEDIAN(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[APPROXMEDIAN(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + Projection: orders.order_id, APPROX_MEDIAN(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[APPROX_MEDIAN(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: orders"; quick_test(sql, expected); } diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index d02c733efc3a..132ba47e2461 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -36,6 +36,8 @@ Aggregate functions operate on a set of values to compute a single result. - [min](#min) - [sum](#sum) - [array_agg](#array_agg) +- [first_value](#first_value) +- [last_value](#last_value) ### `avg` @@ -202,6 +204,32 @@ array_agg(expression [ORDER BY expression]) #### Arguments +- **expression**: Expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `first_value` + +Returns the first element in an aggregation group according to the requested ordering. 
If no ordering is given, returns an arbitrary element from the group. + +``` +first_value(expression [ORDER BY expression]) +``` + +#### Arguments + +- **expression**: Expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `last_value` + +Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. + +``` +last_value(expression [ORDER BY expression]) +``` + +#### Arguments + - **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators.