From b627ca3e78d35cd12a850a7ef181fd8862dbf50f Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 14 Jun 2024 07:23:23 +0800 Subject: [PATCH] Remove builtin count (#10893) * rm expr fn Signed-off-by: jayzhan211 * rm function Signed-off-by: jayzhan211 * fix query and fmt Signed-off-by: jayzhan211 * fix example Signed-off-by: jayzhan211 * Update datafusion/expr/src/test/function_stub.rs Co-authored-by: Andrew Lamb --------- Signed-off-by: jayzhan211 Co-authored-by: Andrew Lamb --- datafusion/expr/src/aggregate_function.rs | 6 -- datafusion/expr/src/expr.rs | 13 --- datafusion/expr/src/expr_fn.rs | 26 ------ datafusion/expr/src/logical_plan/plan.rs | 4 +- datafusion/expr/src/test/function_stub.rs | 86 ++++++++++++++++++- .../expr/src/type_coercion/aggregates.rs | 2 - datafusion/optimizer/Cargo.toml | 1 + .../src/analyzer/count_wildcard_rule.rs | 42 +++------ datafusion/optimizer/src/decorrelate.rs | 10 +-- .../src/eliminate_group_by_constant.rs | 6 +- .../optimizer/src/optimize_projections/mod.rs | 20 ++--- .../src/single_distinct_to_groupby.rs | 52 +++++------ .../optimizer/tests/optimizer_integration.rs | 8 +- .../physical-expr/src/aggregate/build_in.rs | 19 +--- datafusion/proto/Cargo.toml | 1 + datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 1 - datafusion/proto/src/logical_plan/to_proto.rs | 2 - .../tests/cases/roundtrip_logical_plan.rs | 35 +++----- datafusion/sql/examples/sql.rs | 10 ++- datafusion/sql/src/unparser/expr.rs | 45 ++++------ datafusion/sql/src/utils.rs | 3 +- datafusion/sql/tests/cases/plan_to_sql.rs | 6 +- datafusion/sql/tests/common/mod.rs | 3 +- datafusion/sql/tests/sql_integration.rs | 7 +- .../sqllogictest/test_files/functions.slt | 2 +- 28 files changed, 200 insertions(+), 219 deletions(-) diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index e3d2e6555d5c..5899cc927703 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -33,8 +33,6 @@ use strum_macros::EnumIter; // https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum AggregateFunction { - /// Count - Count, /// Minimum Min, /// Maximum @@ -89,7 +87,6 @@ impl AggregateFunction { pub fn name(&self) -> &str { use AggregateFunction::*; match self { - Count => "COUNT", Min => "MIN", Max => "MAX", Avg => "AVG", @@ -135,7 +132,6 @@ impl FromStr for AggregateFunction { "bit_xor" => AggregateFunction::BitXor, "bool_and" => AggregateFunction::BoolAnd, "bool_or" => AggregateFunction::BoolOr, - "count" => AggregateFunction::Count, "max" => AggregateFunction::Max, "mean" => AggregateFunction::Avg, "min" => AggregateFunction::Min, @@ -190,7 +186,6 @@ impl AggregateFunction { })?; match self { - AggregateFunction::Count => Ok(DataType::Int64), AggregateFunction::Max | AggregateFunction::Min => { // For min and max agg function, the returned type is same as input type. // The coerced_data_types is same with input_types. @@ -249,7 +244,6 @@ impl AggregateFunction { pub fn signature(&self) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. match self { - AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable), AggregateFunction::Grouping | AggregateFunction::ArrayAgg => { Signature::any(1, Volatility::Immutable) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 57f5414c13bd..9ba866a4c919 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2135,18 +2135,6 @@ mod test { use super::*; - #[test] - fn test_count_return_type() -> Result<()> { - let fun = find_df_window_func("count").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::Int64, observed); - - Ok(()) - } - #[test] fn test_first_value_return_type() -> Result<()> { let fun = find_df_window_func("first_value").unwrap(); @@ -2250,7 +2238,6 @@ mod test { "nth_value", "min", "max", - "count", "avg", ]; for name in names { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1fafc63e9665..fb5b3991ecd8 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -192,19 +192,6 @@ pub fn avg(expr: Expr) -> Expr { )) } -/// Create an expression to represent the count() aggregate function -// TODO: Remove this and use `expr_fn::count` instead -pub fn count(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![expr], - false, - None, - None, - None, - )) -} - /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( @@ -250,19 +237,6 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { )) } -/// Create an expression to represent the count(distinct) aggregate function -// TODO: Remove this and use `expr_fn::count_distinct` instead -pub fn count_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![expr], - true, - None, - None, - None, - )) -} - /// Create an in_list expression pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { Expr::InList(InList::new(Box::new(expr), list, negated)) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9ea2abe64ede..02378ab3fc1b 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2965,11 +2965,13 @@ mod tests { use super::*; use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; - use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; + use crate::{col, exists, in_subquery, lit, placeholder, GroupingSet}; use datafusion_common::tree_node::TreeNodeVisitor; use datafusion_common::{not_impl_err, Constraint, ScalarValue}; + use crate::test::function_stub::count; + fn employee_schema() -> Schema { Schema::new(vec![ Field::new("id", DataType::Int32, false), diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index b9aa1e636d94..ac98ee9747cc 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -31,7 +31,7 @@ use crate::{ use arrow::datatypes::{ DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, not_impl_err, Result}; macro_rules! create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { @@ -69,6 +69,19 @@ pub fn sum(expr: Expr) -> Expr { )) } +create_func!(Count, count_udaf); + +pub fn count(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + count_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + /// Stub `sum` used for optimizer testing #[derive(Debug)] pub struct Sum { @@ -189,3 +202,74 @@ impl AggregateUDFImpl for Sum { AggregateOrderSensitivity::Insensitive } } + +/// Testing stub implementation of COUNT aggregate +pub struct Count { + signature: Signature, + aliases: Vec, +} + +impl std::fmt::Debug for Count { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Count") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Count { + fn default() -> Self { + Self::new() + } +} + +impl Count { + pub fn new() -> Self { + Self { + aliases: vec!["count".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Count { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "COUNT" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index ab7deaff9885..2c76407cdfe2 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -96,7 +96,6 @@ pub fn coerce_types( check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; match agg_fun { - AggregateFunction::Count => Ok(input_types.to_vec()), AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), AggregateFunction::Min | AggregateFunction::Max => { // min and max support the dictionary data type @@ -525,7 +524,6 @@ mod tests { // test count, array_agg, approx_distinct, min, max. // the coerced types is same with input types let funs = vec![ - AggregateFunction::Count, AggregateFunction::ArrayAgg, AggregateFunction::Min, AggregateFunction::Max, diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index cb14f6bdd4a3..1a9e9630c076 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -56,5 +56,6 @@ regex-syntax = "0.8.0" [dev-dependencies] arrow-buffer = { workspace = true } ctor = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-sql = { workspace = true } env_logger = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index af1c99c52390..de2af520053a 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -25,9 +25,7 @@ use datafusion_expr::expr::{ AggregateFunction, AggregateFunctionDefinition, WindowFunction, }; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{ - aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition, -}; +use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// @@ -56,37 +54,19 @@ fn is_wildcard(expr: &Expr) -> bool { } fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { - match aggregate_function { + matches!(aggregate_function, AggregateFunction { func_def: AggregateFunctionDefinition::UDF(udf), args, .. - } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => true, - AggregateFunction { - func_def: - AggregateFunctionDefinition::BuiltIn( - datafusion_expr::aggregate_function::AggregateFunction::Count, - ), - args, - .. - } if args.len() == 1 && is_wildcard(&args[0]) => true, - _ => false, - } + } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])) } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { let args = &window_function.args; - match window_function.fun { - WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ) if args.len() == 1 && is_wildcard(&args[0]) => true, + matches!(window_function.fun, WindowFunctionDefinition::AggregateUDF(ref udaf) - if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => - { - true - } - _ => false, - } + if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])) } fn analyze_internal(plan: LogicalPlan) -> Result> { @@ -121,14 +101,16 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; - use datafusion_expr::test::function_stub::sum; use datafusion_expr::{ - col, count, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, - out_ref_col, scalar_subquery, wildcard, AggregateFunction, WindowFrame, - WindowFrameBound, WindowFrameUnits, + col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, + out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; + use datafusion_functions_aggregate::count::count_udaf; use std::sync::Arc; + use datafusion_functions_aggregate::expr_fn::{count, sum}; + fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_analyzed_plan_eq_display_indent( Arc::new(CountWildcardRule::new()), @@ -239,7 +221,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index e14ee763a3c0..e949e1921b97 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -432,14 +432,8 @@ fn agg_exprs_evaluation_result_on_empty_batch( Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( - 0, - )))) - } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } + AggregateFunctionDefinition::BuiltIn(_fun) => { + Transformed::yes(Expr::Literal(ScalarValue::Null)) } AggregateFunctionDefinition::UDF(fun) => { if fun.name() == "COUNT" { diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index cef226d67b6c..7a8dd7aac249 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -129,10 +129,12 @@ mod tests { use datafusion_common::Result; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - col, count, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, - Signature, TypeSignature, + col, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, + TypeSignature, }; + use datafusion_functions_aggregate::expr_fn::count; + use std::sync::Arc; #[derive(Debug)] diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index af51814c9686..11540d3e162e 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -818,10 +818,11 @@ mod tests { use datafusion_common::{ Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, }; + use datafusion_expr::AggregateExt; use datafusion_expr::{ binary_expr, build_join_schema, builder::table_scan_with_filters, - col, count, + col, expr::{self, Cast}, lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, @@ -830,6 +831,9 @@ mod tests { WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::count; + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } @@ -1886,16 +1890,10 @@ mod tests { #[test] fn aggregate_filter_pushdown() -> Result<()> { let table_scan = test_table_scan()?; - - let aggr_with_filter = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("b")], - false, - Some(Box::new(col("c").gt(lit(42)))), - None, - None, - )); - + let aggr_with_filter = count_udaf() + .call(vec![col("b")]) + .filter(col("c").gt(lit(42))) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index e738209eb4fd..d3d22eb53f39 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -362,11 +362,13 @@ mod tests { use super::*; use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; - use datafusion_expr::test::function_stub::{sum, sum_udaf}; + use datafusion_expr::AggregateExt; use datafusion_expr::{ - count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, min, - AggregateFunction, + lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, }; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::{count, count_distinct, sum}; + use datafusion_functions_aggregate::sum::sum_udaf; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( @@ -679,14 +681,11 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - Some(Box::new(col("a").gt(lit(5)))), - None, - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .filter(col("a").gt(lit(5))) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; @@ -725,19 +724,16 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a ORDER BY a) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - None, - Some(vec![col("a")]), - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .order_by(vec![col("a").sort(true, false)]) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -748,19 +744,17 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - Some(Box::new(col("a").gt(lit(5)))), - Some(vec![col("a")]), - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .filter(col("a").gt(lit(5))) + .order_by(vec![col("a").sort(true, false)]) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index b3501cca9efa..f60bf6609005 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result}; use datafusion_expr::test::function_stub::sum_udaf; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; +use datafusion_functions_aggregate::count::count_udaf; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; @@ -323,7 +324,9 @@ fn test_sql(sql: &str) -> Result { let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... let ast: Vec = Parser::parse_sql(&dialect, sql).unwrap(); let statement = &ast[0]; - let context_provider = MyContextProvider::default().with_udaf(sum_udaf()); + let context_provider = MyContextProvider::default() + .with_udaf(sum_udaf()) + .with_udaf(count_udaf()); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); @@ -345,7 +348,8 @@ struct MyContextProvider { impl MyContextProvider { fn with_udaf(mut self, udaf: Arc) -> Self { - self.udafs.insert(udaf.name().to_string(), udaf); + // TODO: change to to_string() if all the function name is converted to lowercase + self.udafs.insert(udaf.name().to_lowercase(), udaf); self } } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index aee7bca3b88f..75f2e12320bf 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use datafusion_common::{exec_err, internal_err, not_impl_err, Result}; +use datafusion_common::{exec_err, not_impl_err, Result}; use datafusion_expr::AggregateFunction; use crate::aggregate::average::Avg; @@ -61,9 +61,6 @@ pub fn create_aggregate_expr( .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { - (AggregateFunction::Count, _) => { - return internal_err!("Builtin Count will be removed"); - } (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( input_phy_exprs[0].clone(), name, @@ -642,20 +639,6 @@ mod tests { Ok(()) } - #[test] - fn test_count_return_type() -> Result<()> { - let observed = AggregateFunction::Count.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = AggregateFunction::Count.return_type(&[DataType::Int8])?; - assert_eq!(DataType::Int64, observed); - - let observed = - AggregateFunction::Count.return_type(&[DataType::Decimal128(28, 13)])?; - assert_eq!(DataType::Int64, observed); - Ok(()) - } - #[test] fn test_avg_return_type() -> Result<()> { let observed = AggregateFunction::Avg.return_type(&[DataType::Float32])?; diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index b1897aa58e7d..aa8d0e55b68f 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -59,6 +59,7 @@ serde_json = { workspace = true, optional = true } [dev-dependencies] datafusion-functions = { workspace = true, default-features = true } +datafusion-functions-aggregate = { workspace = true } doc-comment = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 2bb3ec793d7f..31cb0d1da9d5 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -476,7 +476,7 @@ enum AggregateFunction { MAX = 1; // SUM = 2; AVG = 3; - COUNT = 4; + // COUNT = 4; // APPROX_DISTINCT = 5; ARRAY_AGG = 6; // VARIANCE = 7; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 59b7861a6ef1..503f83af65f2 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -535,7 +535,6 @@ impl serde::Serialize for AggregateFunction { Self::Min => "MIN", Self::Max => "MAX", Self::Avg => "AVG", - Self::Count => "COUNT", Self::ArrayAgg => "ARRAY_AGG", Self::Correlation => "CORRELATION", Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", @@ -571,7 +570,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN", "MAX", "AVG", - "COUNT", "ARRAY_AGG", "CORRELATION", "APPROX_PERCENTILE_CONT", @@ -636,7 +634,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN" => Ok(AggregateFunction::Min), "MAX" => Ok(AggregateFunction::Max), "AVG" => Ok(AggregateFunction::Avg), - "COUNT" => Ok(AggregateFunction::Count), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "CORRELATION" => Ok(AggregateFunction::Correlation), "APPROX_PERCENTILE_CONT" => Ok(AggregateFunction::ApproxPercentileCont), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 0861c287fcfa..2c0ea62466b4 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1930,7 +1930,7 @@ pub enum AggregateFunction { Max = 1, /// SUM = 2; Avg = 3, - Count = 4, + /// COUNT = 4; /// APPROX_DISTINCT = 5; ArrayAgg = 6, /// VARIANCE = 7; @@ -1972,7 +1972,6 @@ impl AggregateFunction { AggregateFunction::Min => "MIN", AggregateFunction::Max => "MAX", AggregateFunction::Avg => "AVG", - AggregateFunction::Count => "COUNT", AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Correlation => "CORRELATION", AggregateFunction::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", @@ -2004,7 +2003,6 @@ impl AggregateFunction { "MIN" => Some(Self::Min), "MAX" => Some(Self::Max), "AVG" => Some(Self::Avg), - "COUNT" => Some(Self::Count), "ARRAY_AGG" => Some(Self::ArrayAgg), "CORRELATION" => Some(Self::Correlation), "APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2ad40d883fe6..54a59485c836 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -145,7 +145,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::BitXor => Self::BitXor, protobuf::AggregateFunction::BoolAnd => Self::BoolAnd, protobuf::AggregateFunction::BoolOr => Self::BoolOr, - protobuf::AggregateFunction::Count => Self::Count, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::Correlation => Self::Correlation, protobuf::AggregateFunction::RegrSlope => Self::RegrSlope, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6a275ed7a1b8..80ce05d151ee 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -116,7 +116,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::BitXor => Self::BitXor, AggregateFunction::BoolAnd => Self::BoolAnd, AggregateFunction::BoolOr => Self::BoolOr, - AggregateFunction::Count => Self::Count, AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Correlation => Self::Correlation, AggregateFunction::RegrSlope => Self::RegrSlope, @@ -406,7 +405,6 @@ pub fn serialize_expr( AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index d9736da69d42..d0f1c4aade5e 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -26,6 +26,7 @@ use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; +use datafusion_functions_aggregate::count::count_udaf; use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -35,8 +36,8 @@ use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::approx_median::approx_median; use datafusion::functions_aggregate::expr_fn::{ - covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, var_pop, - var_sample, + count, count_distinct, covar_pop, covar_samp, first_value, median, stddev, + stddev_pop, sum, var_pop, var_sample, }; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -53,10 +54,10 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateFunction, ColumnarValue, ExprSchemable, LogicalPlan, Operator, - PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, - WindowUDFImpl, + Accumulator, AggregateExt, AggregateFunction, ColumnarValue, ExprSchemable, + LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, + TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1782,28 +1783,18 @@ fn roundtrip_similar_to() { #[test] fn roundtrip_count() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("bananas")], - false, - None, - None, - None, - )); + let test_expr = count(col("bananas")); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); } #[test] fn roundtrip_count_distinct() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("bananas")], - true, - None, - None, - None, - )); + let test_expr = count_udaf() + .call(vec![col("bananas")]) + .distinct() + .build() + .unwrap(); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); } diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 893db018c8af..aee4cf5a38ed 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -18,11 +18,12 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result}; -use datafusion_expr::test::function_stub::sum_udaf; use datafusion_expr::WindowUDF; use datafusion_expr::{ logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource, }; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_sql::{ planner::{ContextProvider, SqlToRel}, sqlparser::{dialect::GenericDialect, parser::Parser}, @@ -50,7 +51,9 @@ fn main() { let statement = &ast[0]; // create a logical query plan - let context_provider = MyContextProvider::new().with_udaf(sum_udaf()); + let context_provider = MyContextProvider::new() + .with_udaf(sum_udaf()) + .with_udaf(count_udaf()); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); @@ -66,7 +69,8 @@ struct MyContextProvider { impl MyContextProvider { fn with_udaf(mut self, udaf: Arc) -> Self { - self.udafs.insert(udaf.name().to_string(), udaf); + // TODO: change to to_string() if all the function name is converted to lowercase + self.udafs.insert(udaf.name().to_lowercase(), udaf); self } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index dc25a6c33ece..12c48054f1a7 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -960,13 +960,14 @@ mod tests { use arrow_schema::DataType::Int8; use datafusion_common::TableReference; + use datafusion_expr::AggregateExt; use datafusion_expr::{ - case, col, cube, exists, - expr::{AggregateFunction, AggregateFunctionDefinition}, - grouping_set, lit, not, not_exists, out_ref_col, placeholder, rollup, table_scan, - try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, - Volatility, WindowFrame, WindowFunctionDefinition, + case, col, cube, exists, grouping_set, lit, not, not_exists, out_ref_col, + placeholder, rollup, table_scan, try_cast, when, wildcard, ColumnarValue, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, + WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; use crate::unparser::dialect::CustomDialect; @@ -1127,29 +1128,19 @@ mod tests { ), (sum(col("a")), r#"sum(a)"#), ( - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - datafusion_expr::AggregateFunction::Count, - ), - args: vec![Expr::Wildcard { qualifier: None }], - distinct: true, - filter: None, - order_by: None, - null_treatment: None, - }), + count_udaf() + .call(vec![Expr::Wildcard { qualifier: None }]) + .distinct() + .build() + .unwrap(), "COUNT(DISTINCT *)", ), ( - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - datafusion_expr::AggregateFunction::Count, - ), - args: vec![Expr::Wildcard { qualifier: None }], - distinct: false, - filter: Some(Box::new(lit(true))), - order_by: None, - null_treatment: None, - }), + count_udaf() + .call(vec![Expr::Wildcard { qualifier: None }]) + .filter(lit(true)) + .build() + .unwrap(), "COUNT(*) FILTER (WHERE true)", ), ( @@ -1167,9 +1158,7 @@ mod tests { ), ( Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::AggregateFunction( - datafusion_expr::AggregateFunction::Count, - ), + fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), args: vec![wildcard()], partition_by: vec![], order_by: vec![Expr::Sort(Sort::new( diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 51bacb5f702b..bc27d25cf216 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -350,7 +350,8 @@ mod tests { use arrow::datatypes::{DataType as ArrowDataType, Field, Schema}; use arrow_schema::Fields; use datafusion_common::{DFSchema, Result}; - use datafusion_expr::{col, count, lit, unnest, EmptyRelation, LogicalPlan}; + use datafusion_expr::{col, lit, unnest, EmptyRelation, LogicalPlan}; + use datafusion_functions_aggregate::expr_fn::count; use crate::utils::{recursive_transform_unnest, resolve_positions_to_exprs}; diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 72018371a5f1..33e28e7056b9 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -19,7 +19,7 @@ use std::vec; use arrow_schema::*; use datafusion_common::{DFSchema, Result, TableReference}; -use datafusion_expr::test::function_stub::sum_udaf; +use datafusion_expr::test::function_stub::{count_udaf, sum_udaf}; use datafusion_expr::{col, table_scan}; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ @@ -153,7 +153,9 @@ fn roundtrip_statement() -> Result<()> { .try_with_sql(query)? .parse_statement()?; - let context = MockContextProvider::default().with_udaf(sum_udaf()); + let context = MockContextProvider::default() + .with_udaf(sum_udaf()) + .with_udaf(count_udaf()); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index d91c09ae1287..893678d6b374 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -46,7 +46,8 @@ impl MockContextProvider { } pub(crate) fn with_udaf(mut self, udaf: Arc) -> Self { - self.udafs.insert(udaf.name().to_string(), udaf); + // TODO: change to to_string() if all the function name is converted to lowercase + self.udafs.insert(udaf.name().to_lowercase(), udaf); self } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 7b9d39a2b51e..8eb2a2b609e7 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -37,7 +37,9 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; -use datafusion_functions_aggregate::approx_median::approx_median_udaf; +use datafusion_functions_aggregate::{ + approx_median::approx_median_udaf, count::count_udaf, +}; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; @@ -2702,7 +2704,8 @@ fn logical_plan_with_dialect_and_options( )) .with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64)) .with_udaf(sum_udaf()) - .with_udaf(approx_median_udaf()); + .with_udaf(approx_median_udaf()) + .with_udaf(count_udaf()); let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index f04d76822124..df6295d63b81 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -487,7 +487,7 @@ statement error Did you mean 'to_timestamp_seconds'? SELECT to_TIMESTAMPS_second(v2) from test; # Aggregate function -statement error Did you mean 'COUNT'? +query error DataFusion error: Error during planning: Invalid function 'counter' SELECT counter(*) from test; # Aggregate function