diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 03f981f54a05..51832087cf18 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -62,8 +62,8 @@ use arrow::datatypes::{Schema, SchemaRef}; use async_trait::async_trait; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::{ - self, AggregateFunction, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, - Like, ScalarUDF, TryCast, WindowFunction, + self, AggregateFunction, AggregateUDF, Between, BinaryExpr, Cast, GetIndexedField, + GroupingSet, Like, ScalarUDF, TryCast, WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -199,7 +199,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { args, .. }) => create_function_physical_name(&fun.to_string(), *distinct, args), - Expr::AggregateUDF { fun, args, filter } => { + Expr::AggregateUDF(AggregateUDF { fun, args, filter }) => { if filter.is_some() { return Err(DataFusionError::Execution( "aggregate expression with filter is not supported".to_string(), @@ -1666,7 +1666,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ); Ok((agg_expr?, filter)) } - Expr::AggregateUDF { fun, args, filter } => { + Expr::AggregateUDF(AggregateUDF { fun, args, filter }) => { let args = args .iter() .map(|e| { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 3674c1779deb..39046645a962 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -21,10 +21,10 @@ use crate::aggregate_function; use crate::built_in_function; use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; +use crate::udaf; use crate::utils::{expr_to_columns, find_out_reference_exprs}; use crate::window_frame; use crate::window_function; -use crate::AggregateUDF; use crate::Operator; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -159,14 +159,7 @@ pub enum Expr { /// Represents the call of a window function with arguments. WindowFunction(WindowFunction), /// aggregate function - AggregateUDF { - /// The function - fun: Arc, - /// List of expressions to feed to the functions as arguments - args: Vec, - /// Optional filter applied prior to aggregating - filter: Option>, - }, + AggregateUDF(AggregateUDF), /// Returns whether the list contains the expr value. InList { /// The expression to compare @@ -507,6 +500,27 @@ impl WindowFunction { } } +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct AggregateUDF { + /// The function + pub fun: Arc, + /// List of expressions to feed to the functions as arguments + pub args: Vec, + /// Optional filter + pub filter: Option>, +} + +impl AggregateUDF { + /// Create a new AggregateUDF expression + pub fn new( + fun: Arc, + args: Vec, + filter: Option>, + ) -> Self { + Self { fun, args, filter } + } +} + /// Grouping sets /// See /// for Postgres definition. @@ -988,12 +1002,12 @@ impl fmt::Debug for Expr { } Ok(()) } - Expr::AggregateUDF { + Expr::AggregateUDF(AggregateUDF { fun, ref args, filter, .. - } => { + }) => { fmt_function(f, &fun.name, false, args, false)?; if let Some(fe) = filter { write!(f, " FILTER (WHERE {fe})")?; @@ -1344,7 +1358,7 @@ fn create_name(e: &Expr) -> Result { Ok(name) } } - Expr::AggregateUDF { fun, args, filter } => { + Expr::AggregateUDF(AggregateUDF { fun, args, filter }) => { let mut names = Vec::with_capacity(args.len()); for e in args { names.push(create_name(e)?); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index fdb8a34aba64..9b91f8521057 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,8 +17,8 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, BinaryExpr, Cast, GetIndexedField, ScalarFunction, ScalarUDF, - Sort, TryCast, WindowFunction, + AggregateFunction, AggregateUDF, BinaryExpr, Cast, GetIndexedField, ScalarFunction, + ScalarUDF, Sort, TryCast, WindowFunction, }; use crate::field_util::get_indexed_field; use crate::type_coercion::binary::get_result_type; @@ -123,7 +123,7 @@ impl ExprSchemable for Expr { .collect::>>()?; aggregate_function::return_type(fun, &data_types) } - Expr::AggregateUDF { fun, args, .. } => { + Expr::AggregateUDF(AggregateUDF { fun, args, .. }) => { let data_types = args .iter() .map(|e| e.get_type(schema)) diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 5a3442394d8c..3f173400606e 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -18,8 +18,8 @@ //! Tree node implementation for logical expr use crate::expr::{ - AggregateFunction, Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet, - Like, ScalarFunction, ScalarUDF, Sort, TryCast, WindowFunction, + AggregateFunction, AggregateUDF, Between, BinaryExpr, Case, Cast, GetIndexedField, + GroupingSet, Like, ScalarFunction, ScalarUDF, Sort, TryCast, WindowFunction, }; use crate::Expr; use datafusion_common::tree_node::VisitRecursion; @@ -97,7 +97,7 @@ impl TreeNode for Expr { expr_vec } Expr::AggregateFunction(AggregateFunction { args, filter, .. }) - | Expr::AggregateUDF { args, filter, .. } => { + | Expr::AggregateUDF ( AggregateUDF{args, filter, .. }) => { let mut expr_vec = args.clone(); if let Some(f) = filter { @@ -313,11 +313,13 @@ impl TreeNode for Expr { )) } }, - Expr::AggregateUDF { args, fun, filter } => Expr::AggregateUDF { - args: transform_vec(args, &mut transform)?, - fun, - filter: transform_option_box(filter, &mut transform)?, - }, + Expr::AggregateUDF(AggregateUDF { args, fun, filter }) => { + Expr::AggregateUDF(AggregateUDF { + args: transform_vec(args, &mut transform)?, + fun, + filter: transform_option_box(filter, &mut transform)?, + }) + } Expr::InList { expr, list, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 0ecb5280a942..d681390d27cc 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -86,10 +86,10 @@ impl AggregateUDF { /// creates a logical expression with a call of the UDAF /// This utility allows using the UDAF without requiring access to the registry. pub fn call(&self, args: Vec) -> Expr { - Expr::AggregateUDF { + Expr::AggregateUDF(crate::expr::AggregateUDF { fun: Arc::new(self.clone()), args, filter: None, - } + }) } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 0a411f828aaa..9a77eff77ffa 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -404,17 +404,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { )); Ok(expr) } - Expr::AggregateUDF { fun, args, filter } => { + Expr::AggregateUDF(expr::AggregateUDF { fun, args, filter }) => { let new_expr = coerce_arguments_for_signature( args.as_slice(), &self.schema, &fun.signature, )?; - let expr = Expr::AggregateUDF { - fun, - args: new_expr, - filter, - }; + let expr = + Expr::AggregateUDF(expr::AggregateUDF::new(fun, new_expr, filter)); Ok(expr) } Expr::WindowFunction(WindowFunction { @@ -883,11 +880,11 @@ mod test { }), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); - let udaf = Expr::AggregateUDF { - fun: Arc::new(my_avg), - args: vec![lit(10i64)], - filter: None, - }; + let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + Arc::new(my_avg), + vec![lit(10i64)], + None, + )); let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) @@ -913,11 +910,11 @@ mod test { &accumulator, &state_type, ); - let udaf = Expr::AggregateUDF { - fun: Arc::new(my_avg), - args: vec![lit("10")], - filter: None, - }; + let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + Arc::new(my_avg), + vec![lit("10")], + None, + )); let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "") .err() diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index ae8bb3293d7d..a2db93330918 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -21,7 +21,6 @@ use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; use arrow::datatypes::DataType; - use datafusion_common::tree_node::{ RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, }; @@ -870,16 +869,18 @@ mod test { let accumulator: AccumulatorFunctionImplementation = Arc::new(|_| unimplemented!()); let state_type: StateTypeFunction = Arc::new(|_| unimplemented!()); - let udf_agg = |inner: Expr| Expr::AggregateUDF { - fun: Arc::new(AggregateUDF::new( - "my_agg", - &Signature::exact(vec![DataType::UInt32], Volatility::Stable), - &return_type, - &accumulator, - &state_type, - )), - args: vec![inner], - filter: None, + let udf_agg = |inner: Expr| { + Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new( + Arc::new(AggregateUDF::new( + "my_agg", + &Signature::exact(vec![DataType::UInt32], Volatility::Stable), + &return_type, + &accumulator, + &state_type, + )), + vec![inner], + None, + )) }; // test: common aggregates diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 65407bb555d2..bc1b52406300 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1377,16 +1377,14 @@ pub fn parse_expr( ExprType::AggregateUdfExpr(pb) => { let agg_fn = registry.udaf(pb.fun_name.as_str())?; - Ok(Expr::AggregateUDF { - fun: agg_fn, - args: pb - .args + Ok(Expr::AggregateUDF(expr::AggregateUDF::new( + agg_fn, + pb.args .iter() .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, - filter: parse_optional_expr(pb.filter.as_deref(), registry)? - .map(Box::new), - }) + parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new), + ))) } ExprType::GroupingSet(GroupingSetNode { expr }) => { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 23a6766cc332..087f621cd443 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -2602,11 +2602,11 @@ mod roundtrip_tests { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr = Expr::AggregateUDF { - fun: Arc::new(dummy_agg.clone()), - args: vec![lit(1.0_f64)], - filter: Some(Box::new(lit(true))), - }; + let test_expr = Expr::AggregateUDF(expr::AggregateUDF::new( + Arc::new(dummy_agg.clone()), + vec![lit(1.0_f64)], + Some(Box::new(lit(true))), + )); let ctx = SessionContext::new(); ctx.register_udaf(dummy_agg); diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index d35d214513cf..25fa4714bc2a 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -704,7 +704,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .collect::, Error>>()?, })), }, - Expr::AggregateUDF { fun, args, filter } => Self { + Expr::AggregateUDF(expr::AggregateUDF { fun, args, filter }) => Self { expr_type: Some(ExprType::AggregateUdfExpr(Box::new( protobuf::AggregateUdfExprNode { fun_name: fun.name.clone(), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index fedb7eaacd15..482cd5f0c400 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -128,11 +128,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) { let args = self.function_args_to_expr(function.args, schema, planner_context)?; - return Ok(Expr::AggregateUDF { - fun: fm, - args, - filter: None, - }); + return Ok(Expr::AggregateUDF(expr::AggregateUDF::new(fm, args, None))); } // Special case arrow_cast (as its type is dependent on its argument value) diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index df8891093a22..9a52df7ce911 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -22,8 +22,8 @@ use sqlparser::ast::Ident; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{ - AggregateFunction, Between, BinaryExpr, Case, GetIndexedField, GroupingSet, Like, - ScalarFunction, ScalarUDF, WindowFunction, + AggregateFunction, AggregateUDF, Between, BinaryExpr, Case, GetIndexedField, + GroupingSet, Like, ScalarFunction, ScalarUDF, WindowFunction, }; use datafusion_expr::expr::{Cast, Sort}; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; @@ -197,14 +197,15 @@ where .collect::>>()?, window_frame.clone(), ))), - Expr::AggregateUDF { fun, args, filter } => Ok(Expr::AggregateUDF { - fun: fun.clone(), - args: args - .iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - filter: filter.clone(), - }), + Expr::AggregateUDF(AggregateUDF { fun, args, filter }) => { + Ok(Expr::AggregateUDF(AggregateUDF::new( + fun.clone(), + args.iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, + filter.clone(), + ))) + } Expr::Alias(nested_expr, alias_name) => Ok(Expr::Alias( Box::new(clone_with_replacement(nested_expr, replacement_fn)?), alias_name.clone(),