diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 020858fbfc3f..654386666488 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -976,7 +976,7 @@ impl TryInto for &protobuf::LogicalExprNode { Ok(Expr::AggregateFunction { fun, - args: vec![parse_required_expr(&expr.expr)?], + arg: Box::new(parse_required_expr(&expr.expr)?), distinct: false, //TODO }) } diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 47e27483ff30..0d7b1e400978 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1062,7 +1062,7 @@ impl TryInto for &Expr { }) } Expr::AggregateFunction { - ref fun, ref args, .. + ref fun, ref arg, .. } => { let aggr_function = match fun { AggregateFunction::Min => protobuf::AggregateFunction::Min, @@ -1072,7 +1072,7 @@ impl TryInto for &Expr { AggregateFunction::Count => protobuf::AggregateFunction::Count, }; - let arg = &args[0]; + let arg = &**arg; let aggregate_expr = Box::new(protobuf::AggregateExprNode { aggr_function: aggr_function.into(), expr: Some(Box::new(arg.try_into()?)), diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index d034f3ca3bfe..ced8a8d32208 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -325,10 +325,10 @@ impl TryInto> for &protobuf::PhysicalPlanNode { let df_planner = DefaultPhysicalPlanner::default(); for (expr, name) in &logical_agg_expr { match expr { - Expr::AggregateFunction { fun, args, .. } => { + Expr::AggregateFunction { fun, arg, .. } => { let arg = df_planner .create_physical_expr( - &args[0], + &**arg, &physical_schema, &ctx_state, ) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 29723e73d25c..46c8bf84d1d0 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -187,7 +187,7 @@ pub enum Expr { /// Name of the function fun: aggregates::AggregateFunction, /// List of expressions to feed to the functions as arguments - args: Vec, + arg: Box, /// Whether this is a DISTINCT aggregation or not distinct: bool, }, @@ -259,12 +259,9 @@ impl Expr { .collect::>>()?; window_functions::return_type(fun, &data_types) } - Expr::AggregateFunction { fun, args, .. } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - aggregates::return_type(fun, &data_types) + Expr::AggregateFunction { fun, arg, .. } => { + let data_type = arg.get_type(schema)?; + aggregates::return_type(fun, &[data_type]) } Expr::AggregateUDF { fun, args, .. } => { let data_types = args @@ -590,9 +587,7 @@ impl Expr { Expr::WindowFunction { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), - Expr::AggregateFunction { args, .. } => args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)), + Expr::AggregateFunction { arg, .. } => arg.accept(visitor), Expr::AggregateUDF { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), @@ -728,11 +723,11 @@ impl Expr { fun, }, Expr::AggregateFunction { - args, + arg, fun, distinct, } => Expr::AggregateFunction { - args: rewrite_vec(args, rewriter)?, + arg: rewrite_boxed(arg, rewriter)?, fun, distinct, }, @@ -969,7 +964,7 @@ pub fn min(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Min, distinct: false, - args: vec![expr], + arg: Box::new(expr), } } @@ -978,7 +973,7 @@ pub fn max(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Max, distinct: false, - args: vec![expr], + arg: Box::new(expr), } } @@ -987,7 +982,7 @@ pub fn sum(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Sum, distinct: false, - args: vec![expr], + arg: Box::new(expr), } } @@ -996,7 +991,7 @@ pub fn avg(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Avg, distinct: false, - args: vec![expr], + arg: Box::new(expr), } } @@ -1005,7 +1000,7 @@ pub fn count(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Count, distinct: false, - args: vec![expr], + arg: Box::new(expr), } } @@ -1014,7 +1009,7 @@ pub fn count_distinct(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Count, distinct: true, - args: vec![expr], + arg: Box::new(expr), } } @@ -1276,9 +1271,9 @@ impl fmt::Debug for Expr { Expr::AggregateFunction { fun, distinct, - ref args, + ref arg, .. - } => fmt_function(f, &fun.to_string(), *distinct, args), + } => fmt_function(f, &fun.to_string(), *distinct, &[*arg.clone()]), Expr::AggregateUDF { fun, ref args, .. } => { fmt_function(f, &fun.name, false, args) } @@ -1394,9 +1389,9 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { Expr::AggregateFunction { fun, distinct, - args, + arg, .. - } => create_function_name(&fun.to_string(), *distinct, args, input_schema), + } => create_function_name(&fun.to_string(), *distinct, &[*arg.clone()], input_schema), Expr::AggregateUDF { fun, args } => { let mut names = Vec::with_capacity(args.len()); for e in args { diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 284ead252ac6..ae315090fee9 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -266,7 +266,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::ScalarFunction { args, .. } => Ok(args.clone()), Expr::ScalarUDF { args, .. } => Ok(args.clone()), Expr::WindowFunction { args, .. } => Ok(args.clone()), - Expr::AggregateFunction { args, .. } => Ok(args.clone()), + Expr::AggregateFunction { arg, .. } => Ok(vec![arg.as_ref().to_owned()]), Expr::AggregateUDF { args, .. } => Ok(args.clone()), Expr::Case { expr, @@ -344,7 +344,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { }), Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction { fun: fun.clone(), - args: expressions.to_vec(), + arg: Box::new(expressions[0].clone()), distinct: *distinct, }), Expr::AggregateUDF { fun, .. } => Ok(Expr::AggregateUDF { diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 7ddfaf8f6897..49863ec78f73 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -779,19 +779,14 @@ impl DefaultPhysicalPlanner { Expr::AggregateFunction { fun, distinct, - args, + arg, .. } => { - let args = args - .iter() - .map(|e| { - self.create_physical_expr(e, physical_input_schema, ctx_state) - }) - .collect::>>()?; + let args = self.create_physical_expr(arg, physical_input_schema, ctx_state)?; aggregates::create_aggregate_expr( fun, *distinct, - &args, + &[args], physical_input_schema, name, ) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index a3027e589985..9c41dc0292f5 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -1136,7 +1136,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Ok(Expr::AggregateFunction { fun, distinct: function.distinct, - args, + arg: Box::new(args[0].clone()), }); }; diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 70b9df060839..84df5bdf87c5 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -215,14 +215,11 @@ where None => match expr { Expr::AggregateFunction { fun, - args, + arg, distinct, } => Ok(Expr::AggregateFunction { fun: fun.clone(), - args: args - .iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, + arg: Box::new(clone_with_replacement(arg, replacement_fn)?), distinct: *distinct, }), Expr::WindowFunction { fun, args } => Ok(Expr::WindowFunction {