From 454f378287b47d7fe143ef6a2fc02e8cc1484aa1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 29 Aug 2022 10:45:38 -0600 Subject: [PATCH 1/3] Add Aggregate::try_new with validation checks --- datafusion/core/src/physical_plan/planner.rs | 10 ++++---- datafusion/expr/src/logical_plan/builder.rs | 8 +++--- datafusion/expr/src/logical_plan/plan.rs | 25 +++++++++++++++++++ datafusion/expr/src/utils.rs | 12 ++++----- .../optimizer/src/common_subexpr_eliminate.rs | 12 ++++----- .../optimizer/src/projection_push_down.rs | 12 ++++----- .../src/single_distinct_to_groupby.rs | 24 +++++++++--------- 7 files changed, 64 insertions(+), 39 deletions(-) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 658da6a0fef6..2bfc2a88496c 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -621,12 +621,12 @@ impl DefaultPhysicalPlanner { LogicalPlan::Distinct(Distinct {input}) => { // Convert distinct to groupby with no aggregations let group_expr = expand_wildcard(input.schema(), input)?; - let aggregate = LogicalPlan::Aggregate(Aggregate { - input: input.clone(), + let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + input.clone(), group_expr, - aggr_expr: vec![], - schema: input.schema().clone() - } + vec![], + input.schema().clone() + )? ); Ok(self.create_initial_plan(&aggregate, session_state).await?) } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 9eb379142ea6..2409ab481e11 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -700,12 +700,12 @@ impl LogicalPlanBuilder { exprlist_to_fields(all_expr, &self.plan)?, self.plan.schema().metadata().clone(), )?; - Ok(Self::from(LogicalPlan::Aggregate(Aggregate { - input: Arc::new(self.plan.clone()), + Ok(Self::from(LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(self.plan.clone()), group_expr, aggr_expr, - schema: DFSchemaRef::new(aggr_schema), - }))) + DFSchemaRef::new(aggr_schema), + )?))) } /// Create an expression to represent the explanation of the plan diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 2d5eb46804df..8e7e13154b88 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1314,6 +1314,31 @@ pub struct Aggregate { } impl Aggregate { + pub fn try_new( + input: Arc, + group_expr: Vec, + aggr_expr: Vec, + schema: DFSchemaRef, + ) -> datafusion_common::Result { + if group_expr.is_empty() && aggr_expr.is_empty() { + return Err(DataFusionError::Plan( + "Aggregate requires at least one grouping or aggregate expression" + .to_string(), + )); + } + if schema.fields().len() != group_expr.len() + aggr_expr.len() { + return Err(DataFusionError::Plan( + "Aggregate schema has wrong number of fields".to_string(), + )); + } + Ok(Self { + input, + group_expr, + aggr_expr, + schema, + }) + } + pub fn try_from_plan(plan: &LogicalPlan) -> datafusion_common::Result<&Aggregate> { match plan { LogicalPlan::Aggregate(it) => Ok(it), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 367c722d220a..4eff5b618351 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -389,12 +389,12 @@ pub fn from_plan( })), LogicalPlan::Aggregate(Aggregate { group_expr, schema, .. - }) => Ok(LogicalPlan::Aggregate(Aggregate { - group_expr: expr[0..group_expr.len()].to_vec(), - aggr_expr: expr[group_expr.len()..].to_vec(), - input: Arc::new(inputs[0].clone()), - schema: schema.clone(), - })), + }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(inputs[0].clone()), + expr[0..group_expr.len()].to_vec(), + expr[group_expr.len()..].to_vec(), + schema.clone(), + )?)), LogicalPlan::Sort(Sort { .. }) => Ok(LogicalPlan::Sort(Sort { expr: expr.to_vec(), input: Arc::new(inputs[0].clone()), diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 8627b404dce8..cec2de32e772 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -189,12 +189,12 @@ fn optimize( let new_aggr_expr = new_expr.pop().unwrap(); let new_group_expr = new_expr.pop().unwrap(); - Ok(LogicalPlan::Aggregate(Aggregate { - input: Arc::new(new_input), - group_expr: new_group_expr, - aggr_expr: new_aggr_expr, - schema: schema.clone(), - })) + Ok(LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(new_input), + new_group_expr, + new_aggr_expr, + schema.clone(), + )?)) } LogicalPlan::Sort(Sort { expr, input }) => { let arrays = to_arrays(expr, input, &mut expr_set)?; diff --git a/datafusion/optimizer/src/projection_push_down.rs b/datafusion/optimizer/src/projection_push_down.rs index aa3cdfb4252c..80cc1044df25 100644 --- a/datafusion/optimizer/src/projection_push_down.rs +++ b/datafusion/optimizer/src/projection_push_down.rs @@ -345,18 +345,18 @@ fn optimize_plan( schema.metadata().clone(), )?; - Ok(LogicalPlan::Aggregate(Aggregate { - group_expr: group_expr.clone(), - aggr_expr: new_aggr_expr, - input: Arc::new(optimize_plan( + Ok(LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(optimize_plan( _optimizer, input, &new_required_columns, true, _optimizer_config, )?), - schema: DFSchemaRef::new(new_schema), - })) + group_expr.clone(), + new_aggr_expr, + DFSchemaRef::new(new_schema), + )?)) } // scans: // * remove un-used columns from the scan projection diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index a4d6619f2a4f..df6c8f949ec3 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -102,12 +102,12 @@ fn optimize(plan: &LogicalPlan) -> Result { input.schema().metadata().clone(), ) .unwrap(); - let grouped_agg = LogicalPlan::Aggregate(Aggregate { - input: input.clone(), - group_expr: all_group_args, - aggr_expr: Vec::new(), - schema: Arc::new(grouped_schema.clone()), - }); + let grouped_agg = LogicalPlan::Aggregate(Aggregate::try_new( + input.clone(), + all_group_args, + Vec::new(), + Arc::new(grouped_schema.clone()), + )?); let grouped_agg = optimize_children(&grouped_agg); let final_agg_schema = Arc::new( DFSchema::new_with_metadata( @@ -134,12 +134,12 @@ fn optimize(plan: &LogicalPlan) -> Result { )); }); - let final_agg = LogicalPlan::Aggregate(Aggregate { - input: Arc::new(grouped_agg.unwrap()), - group_expr: group_expr.clone(), - aggr_expr: new_aggr_expr, - schema: final_agg_schema, - }); + let final_agg = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(grouped_agg.unwrap()), + group_expr.clone(), + new_aggr_expr, + final_agg_schema, + )?); Ok(LogicalPlan::Projection(Projection::try_new_with_schema( alias_expr, From 04b44e0ea7c0b6bc430c33826e6d0c3072fe7d2d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 29 Aug 2022 10:59:56 -0600 Subject: [PATCH 2/3] fix calculation of number of grouping expressions --- datafusion/expr/src/logical_plan/plan.rs | 5 +++-- datafusion/expr/src/utils.rs | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 8e7e13154b88..b9b8763167a7 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -17,7 +17,7 @@ use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; -use crate::utils::exprlist_to_fields; +use crate::utils::{exprlist_to_fields, grouping_set_expr_count}; use crate::{Expr, TableProviderFilterPushDown, TableSource}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{plan_err, Column, DFSchema, DFSchemaRef, DataFusionError}; @@ -1326,7 +1326,8 @@ impl Aggregate { .to_string(), )); } - if schema.fields().len() != group_expr.len() + aggr_expr.len() { + let group_expr_count = grouping_set_expr_count(&group_expr)?; + if schema.fields().len() != group_expr_count + aggr_expr.len() { return Err(DataFusionError::Plan( "Aggregate schema has wrong number of fields".to_string(), )); diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 4eff5b618351..1a764ce19608 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -45,6 +45,22 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result Ok(()) } +/// Count the number of distinct exprs in a list of group by expressions. If the +/// first element is a `GroupingSet` expression then it must be the only expr. +pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { + if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { + if group_expr.len() > 1 { + return Err(DataFusionError::Plan( + "Invalid group by expressions, GroupingSet must be the only expression" + .to_string(), + )); + } + Ok(grouping_set.distinct_expr().len()) + } else { + Ok(group_expr.len()) + } +} + /// Find all distinct exprs in a list of group by expressions. If the /// first element is a `GroupingSet` expression then it must be the only expr. pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { From 698fe032f817c30bd32d5fd9147aad8559b83fe5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 31 Aug 2022 08:00:54 -0600 Subject: [PATCH 3/3] use suggested error message --- datafusion/expr/src/logical_plan/plan.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index b9b8763167a7..cec55bfc1b1e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1328,9 +1328,11 @@ impl Aggregate { } let group_expr_count = grouping_set_expr_count(&group_expr)?; if schema.fields().len() != group_expr_count + aggr_expr.len() { - return Err(DataFusionError::Plan( - "Aggregate schema has wrong number of fields".to_string(), - )); + return Err(DataFusionError::Plan(format!( + "Aggregate schema has wrong number of fields. Expected {} got {}", + group_expr_count + aggr_expr.len(), + schema.fields().len() + ))); } Ok(Self { input,