From 88390a0532af85b5874323fb8dae6477e62b085c Mon Sep 17 00:00:00 2001 From: Jacob Sherin Date: Sun, 1 Sep 2024 00:33:56 +0530 Subject: [PATCH 1/2] Makes `nth_value` expression API public --- datafusion/functions-aggregate/src/lib.rs | 1 + .../functions-aggregate/src/nth_value.rs | 28 ++++++++++++++----- .../tests/cases/roundtrip_logical_plan.rs | 19 +++++++++++++ 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index b54cd181a0cb..ca0276d326a4 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -113,6 +113,7 @@ pub mod expr_fn { pub use super::median::median; pub use super::min_max::max; pub use super::min_max::min; + pub use super::nth_value::nth_value; pub use super::regr::regr_avgx; pub use super::regr::regr_avgy; pub use super::regr::regr_count; diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 7425bdfa18e7..424063e8b7d7 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -30,19 +30,33 @@ use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValu use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ReversedUDAF, Signature, Volatility, + lit, Accumulator, AggregateUDFImpl, ExprFunctionExt, ReversedUDAF, Signature, + SortExpr, Volatility, }; use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; use datafusion_functions_aggregate_common::utils::ordering_fields; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -make_udaf_expr_and_func!( - NthValueAgg, - nth_value, - "Returns the nth value in a group of values.", - nth_value_udaf -); 
+create_func!(NthValueAgg, nth_value_udaf); + +/// Returns the nth value in a group of values. +pub fn nth_value( + expr: datafusion_expr::Expr, + n: i64, + order_by: Option<Vec<SortExpr>>, +) -> datafusion_expr::Expr { + let args = vec![expr, lit(n)]; + if let Some(order_by) = order_by { + nth_value_udaf() + .call(args) + .order_by(order_by) + .build() + .unwrap() + } else { + nth_value_udaf().call(args) + } +} /// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi /// partition setting, partial aggregations are computed for every partition, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index e174d1b50713..9407ee603172 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -71,6 +71,7 @@ use datafusion_expr::{ use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::expr_fn::{ approx_distinct, array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, + nth_value, }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ @@ -903,6 +904,24 @@ async fn roundtrip_expr_api() -> Result<()> { vec![lit(10), lit(20), lit(30)], ), row_number(), + nth_value(col("b"), 1, None), + nth_value( + col("b"), + 1, + Some(vec![ + col("a").sort(false, false), + col("b").sort(true, false), + ]), + ), + nth_value(col("b"), -1, None), + nth_value( + col("b"), + -1, + Some(vec![ + col("a").sort(false, false), + col("b").sort(true, false), + ]), + ), ]; // ensure expressions created with the expr api can be round tripped From 19bf6663a8be1f496cc380545f6fb2acc9c863f1 Mon Sep 17 00:00:00 2001 From: Jacob Sherin Date: Mon, 2 Sep 2024 15:23:23 +0530 Subject: [PATCH 2/2] Updates type of `order_by` parameter --- datafusion/functions-aggregate/src/nth_value.rs | 4 ++-- .../proto/tests/cases/roundtrip_logical_plan.rs | 14 ++++---------- 2 files changed, 6 
insertions(+), 12 deletions(-) diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 424063e8b7d7..bbfe56914c91 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -44,10 +44,10 @@ create_func!(NthValueAgg, nth_value_udaf); pub fn nth_value( expr: datafusion_expr::Expr, n: i64, - order_by: Option<Vec<SortExpr>>, + order_by: Vec<SortExpr>, ) -> datafusion_expr::Expr { let args = vec![expr, lit(n)]; - if let Some(order_by) = order_by { + if !order_by.is_empty() { nth_value_udaf() .call(args) .order_by(order_by) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9407ee603172..994ed8ad2352 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -904,23 +904,17 @@ async fn roundtrip_expr_api() -> Result<()> { vec![lit(10), lit(20), lit(30)], ), row_number(), - nth_value(col("b"), 1, None), + nth_value(col("b"), 1, vec![]), nth_value( col("b"), 1, - Some(vec![ - col("a").sort(false, false), - col("b").sort(true, false), - ]), + vec![col("a").sort(false, false), col("b").sort(true, false)], ), - nth_value(col("b"), -1, None), + nth_value(col("b"), -1, vec![]), nth_value( col("b"), -1, - Some(vec![ - col("a").sort(false, false), - col("b").sort(true, false), - ]), + vec![col("a").sort(false, false), col("b").sort(true, false)], ), ];