From b543f4375c03439e2f1896d9a371cf0da96b1a8d Mon Sep 17 00:00:00 2001 From: Daniel Heres Date: Sun, 30 May 2021 12:17:45 +0200 Subject: [PATCH 1/3] Make AggregateFunction take a single argument --- .../core/src/serde/logical_plan/from_proto.rs | 2 +- .../core/src/serde/logical_plan/to_proto.rs | 4 +- .../src/serde/physical_plan/from_proto.rs | 4 +- datafusion/src/logical_plan/expr.rs | 39 ++++++++----------- datafusion/src/optimizer/utils.rs | 4 +- datafusion/src/physical_plan/planner.rs | 11 ++---- datafusion/src/sql/planner.rs | 2 +- datafusion/src/sql/utils.rs | 7 +--- 8 files changed, 30 insertions(+), 43 deletions(-) diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 020858fbfc3f..654386666488 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -976,7 +976,7 @@ impl TryInto for &protobuf::LogicalExprNode { Ok(Expr::AggregateFunction { fun, - args: vec![parse_required_expr(&expr.expr)?], + arg: Box::new(parse_required_expr(&expr.expr)?), distinct: false, //TODO }) } diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 47e27483ff30..8ca552b1e9b9 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1062,7 +1062,7 @@ impl TryInto for &Expr { }) } Expr::AggregateFunction { - ref fun, ref args, .. + ref fun, arg: ref args, .. } => { let aggr_function = match fun { AggregateFunction::Min => protobuf::AggregateFunction::Min, @@ -1072,7 +1072,7 @@ impl TryInto for &Expr { AggregateFunction::Count => protobuf::AggregateFunction::Count, }; - let arg = &args[0]; + let arg = &**args; let aggregate_expr = Box::new(protobuf::AggregateExprNode { aggr_function: aggr_function.into(), expr: Some(Box::new(arg.try_into()?)), diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index d034f3ca3bfe..ced8a8d32208 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -325,10 +325,10 @@ impl TryInto> for &protobuf::PhysicalPlanNode { let df_planner = DefaultPhysicalPlanner::default(); for (expr, name) in &logical_agg_expr { match expr { - Expr::AggregateFunction { fun, args, .. } => { + Expr::AggregateFunction { fun, arg, .. } => { let arg = df_planner .create_physical_expr( - &args[0], + &**arg, &physical_schema, &ctx_state, ) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 29723e73d25c..5326b192b680 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -187,7 +187,7 @@ pub enum Expr { /// Name of the function fun: aggregates::AggregateFunction, /// List of expressions to feed to the functions as arguments - args: Vec, + arg: Box, /// Whether this is a DISTINCT aggregation or not distinct: bool, }, @@ -259,12 +259,9 @@ impl Expr { .collect::>>()?; window_functions::return_type(fun, &data_types) } - Expr::AggregateFunction { fun, args, .. } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - aggregates::return_type(fun, &data_types) + Expr::AggregateFunction { fun, arg: args, .. } => { + let data_type = args.get_type(schema)?; + aggregates::return_type(fun, &[data_type]) } Expr::AggregateUDF { fun, args, .. } => { let data_types = args @@ -590,9 +587,7 @@ impl Expr { Expr::WindowFunction { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), - Expr::AggregateFunction { args, .. } => args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)), + Expr::AggregateFunction { arg: args, .. } => args.accept(visitor), Expr::AggregateUDF { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), @@ -728,11 +723,11 @@ impl Expr { fun, }, Expr::AggregateFunction { - args, + arg: args, fun, distinct, } => Expr::AggregateFunction { - args: rewrite_vec(args, rewriter)?, + arg: rewrite_boxed(args, rewriter)?, fun, distinct, }, @@ -969,7 +964,7 @@ pub fn min(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Min, distinct: false, - args: vec![expr], + arg: Box::new(expr), } } @@ -978,7 +973,7 @@ pub fn max(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Max, distinct: false, - args: vec![expr], + arg: Box::new(expr), } } @@ -987,7 +982,7 @@ pub fn sum(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Sum, distinct: false, - args: vec![expr], + arg: Box::new(expr), } } @@ -996,7 +991,7 @@ pub fn avg(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Avg, distinct: false, - args: vec![expr], + arg: Box::new(expr), } } @@ -1005,7 +1000,7 @@ pub fn count(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Count, distinct: false, - args: vec![expr], + arg: Box::new(expr), } } @@ -1014,7 +1009,7 @@ pub fn count_distinct(expr: Expr) -> Expr { Expr::AggregateFunction { fun: aggregates::AggregateFunction::Count, distinct: true, - args: vec![expr], + arg: Box::new(expr), } } @@ -1276,9 +1271,9 @@ impl fmt::Debug for Expr { Expr::AggregateFunction { fun, distinct, - ref args, + arg: ref args, .. - } => fmt_function(f, &fun.to_string(), *distinct, args), + } => fmt_function(f, &fun.to_string(), *distinct, &[*args.clone()]), Expr::AggregateUDF { fun, ref args, .. } => { fmt_function(f, &fun.name, false, args) } @@ -1394,9 +1389,9 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { Expr::AggregateFunction { fun, distinct, - args, + arg: args, .. - } => create_function_name(&fun.to_string(), *distinct, args, input_schema), + } => create_function_name(&fun.to_string(), *distinct, &[*args.clone()], input_schema), Expr::AggregateUDF { fun, args } => { let mut names = Vec::with_capacity(args.len()); for e in args { diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 284ead252ac6..5c503ce344ba 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -266,7 +266,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::ScalarFunction { args, .. } => Ok(args.clone()), Expr::ScalarUDF { args, .. } => Ok(args.clone()), Expr::WindowFunction { args, .. } => Ok(args.clone()), - Expr::AggregateFunction { args, .. } => Ok(args.clone()), + Expr::AggregateFunction { arg: args, .. } => Ok(vec![args.as_ref().to_owned()]), Expr::AggregateUDF { args, .. } => Ok(args.clone()), Expr::Case { expr, @@ -344,7 +344,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { }), Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction { fun: fun.clone(), - args: expressions.to_vec(), + arg: Box::new(expressions[0].clone()), distinct: *distinct, }), Expr::AggregateUDF { fun, .. } => Ok(Expr::AggregateUDF { diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 7ddfaf8f6897..a41fc2328196 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -779,19 +779,14 @@ impl DefaultPhysicalPlanner { Expr::AggregateFunction { fun, distinct, - args, + arg: args, .. } => { - let args = args - .iter() - .map(|e| { - self.create_physical_expr(e, physical_input_schema, ctx_state) - }) - .collect::>>()?; + let args = self.create_physical_expr(args, physical_input_schema, ctx_state)?; aggregates::create_aggregate_expr( fun, *distinct, - &args, + &[args], physical_input_schema, name, ) diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index a3027e589985..9c41dc0292f5 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -1136,7 +1136,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Ok(Expr::AggregateFunction { fun, distinct: function.distinct, - args, + arg: Box::new(args[0].clone()), }); }; diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 70b9df060839..c3012113c0ee 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -215,14 +215,11 @@ where None => match expr { Expr::AggregateFunction { fun, - args, + arg: args, distinct, } => Ok(Expr::AggregateFunction { fun: fun.clone(), - args: args - .iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, + arg: Box::new(clone_with_replacement(args, replacement_fn)?), distinct: *distinct, }), Expr::WindowFunction { fun, args } => Ok(Expr::WindowFunction { From 29872397a7eac485c52158ad097762c08d18678c Mon Sep 17 00:00:00 2001 From: Daniel Heres Date: Sun, 30 May 2021 12:24:47 +0200 Subject: [PATCH 2/3] Remove args bindings --- datafusion/src/logical_plan/expr.rs | 14 +++++++------- datafusion/src/optimizer/utils.rs | 2 +- datafusion/src/physical_plan/planner.rs | 4 ++-- datafusion/src/sql/utils.rs | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 5326b192b680..c742f5d9a16f 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -259,8 +259,8 @@ impl Expr { .collect::>>()?; window_functions::return_type(fun, &data_types) } - Expr::AggregateFunction { fun, arg: args, .. } => { - let data_type = args.get_type(schema)?; + Expr::AggregateFunction { fun, arg, .. } => { + let data_type = arg.get_type(schema)?; aggregates::return_type(fun, &[data_type]) } Expr::AggregateUDF { fun, args, .. } => { @@ -587,7 +587,7 @@ impl Expr { Expr::WindowFunction { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), - Expr::AggregateFunction { arg: args, .. } => args.accept(visitor), + Expr::AggregateFunction { arg, .. } => arg.accept(visitor), Expr::AggregateUDF { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), @@ -723,11 +723,11 @@ impl Expr { fun, }, Expr::AggregateFunction { - arg: args, + arg, fun, distinct, } => Expr::AggregateFunction { - arg: rewrite_boxed(args, rewriter)?, + arg: rewrite_boxed(arg, rewriter)?, fun, distinct, }, @@ -1389,9 +1389,9 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { Expr::AggregateFunction { fun, distinct, - arg: args, + arg, .. - } => create_function_name(&fun.to_string(), *distinct, &[*args.clone()], input_schema), + } => create_function_name(&fun.to_string(), *distinct, &[*arg.clone()], input_schema), Expr::AggregateUDF { fun, args } => { let mut names = Vec::with_capacity(args.len()); for e in args { diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 5c503ce344ba..ae315090fee9 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -266,7 +266,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::ScalarFunction { args, .. } => Ok(args.clone()), Expr::ScalarUDF { args, .. } => Ok(args.clone()), Expr::WindowFunction { args, .. } => Ok(args.clone()), - Expr::AggregateFunction { arg: args, .. } => Ok(vec![args.as_ref().to_owned()]), + Expr::AggregateFunction { arg, .. } => Ok(vec![arg.as_ref().to_owned()]), Expr::AggregateUDF { args, .. } => Ok(args.clone()), Expr::Case { expr, diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index a41fc2328196..49863ec78f73 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -779,10 +779,10 @@ impl DefaultPhysicalPlanner { Expr::AggregateFunction { fun, distinct, - arg: args, + arg, .. } => { - let args = self.create_physical_expr(args, physical_input_schema, ctx_state)?; + let args = self.create_physical_expr(arg, physical_input_schema, ctx_state)?; aggregates::create_aggregate_expr( fun, *distinct, diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index c3012113c0ee..84df5bdf87c5 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -215,11 +215,11 @@ where None => match expr { Expr::AggregateFunction { fun, - arg: args, + arg, distinct, } => Ok(Expr::AggregateFunction { fun: fun.clone(), - arg: Box::new(clone_with_replacement(args, replacement_fn)?), + arg: Box::new(clone_with_replacement(arg, replacement_fn)?), distinct: *distinct, }), Expr::WindowFunction { fun, args } => Ok(Expr::WindowFunction { From 16573acd14ab377f9868d44a362ee7a520e2eac8 Mon Sep 17 00:00:00 2001 From: Daniel Heres Date: Sun, 30 May 2021 12:28:07 +0200 Subject: [PATCH 3/3] Remove args bindings --- ballista/rust/core/src/serde/logical_plan/to_proto.rs | 4 ++-- datafusion/src/logical_plan/expr.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 8ca552b1e9b9..0d7b1e400978 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1062,7 +1062,7 @@ impl TryInto for &Expr { }) } Expr::AggregateFunction { - ref fun, arg: ref args, .. + ref fun, ref arg, .. } => { let aggr_function = match fun { AggregateFunction::Min => protobuf::AggregateFunction::Min, @@ -1072,7 +1072,7 @@ impl TryInto for &Expr { AggregateFunction::Count => protobuf::AggregateFunction::Count, }; - let arg = &**args; + let arg = &**arg; let aggregate_expr = Box::new(protobuf::AggregateExprNode { aggr_function: aggr_function.into(), expr: Some(Box::new(arg.try_into()?)), diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index c742f5d9a16f..46c8bf84d1d0 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1271,9 +1271,9 @@ impl fmt::Debug for Expr { Expr::AggregateFunction { fun, distinct, - arg: ref args, + ref arg, .. - } => fmt_function(f, &fun.to_string(), *distinct, &[*args.clone()]), + } => fmt_function(f, &fun.to_string(), *distinct, &[*arg.clone()]), Expr::AggregateUDF { fun, ref args, .. } => { fmt_function(f, &fun.name, false, args) }