diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 747cd1a204b5..a67e2dac7c22 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -645,12 +645,12 @@ impl DefaultPhysicalPlanner { LogicalPlan::Distinct(Distinct {input}) => { // Convert distinct to groupby with no aggregations let group_expr = expand_wildcard(input.schema(), input)?; - let aggregate = LogicalPlan::Aggregate(Aggregate { - input: input.clone(), + let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + input.clone(), group_expr, - aggr_expr: vec![], - schema: input.schema().clone() - } + vec![], + input.schema().clone() + )? ); Ok(self.create_initial_plan(&aggregate, session_state).await?) } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2946a74afd70..41ba95140060 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -701,12 +701,12 @@ impl LogicalPlanBuilder { exprlist_to_fields(all_expr, &self.plan)?, self.plan.schema().metadata().clone(), )?; - Ok(Self::from(LogicalPlan::Aggregate(Aggregate { - input: Arc::new(self.plan.clone()), + Ok(Self::from(LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(self.plan.clone()), group_expr, aggr_expr, - schema: DFSchemaRef::new(aggr_schema), - }))) + DFSchemaRef::new(aggr_schema), + )?))) } /// Create an expression to represent the explanation of the plan diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 2d5eb46804df..cec55bfc1b1e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -17,7 +17,7 @@ use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; -use crate::utils::exprlist_to_fields; +use crate::utils::{exprlist_to_fields, grouping_set_expr_count}; use crate::{Expr, TableProviderFilterPushDown, TableSource}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{plan_err, Column, DFSchema, DFSchemaRef, DataFusionError}; @@ -1314,6 +1314,34 @@ pub struct Aggregate { } impl Aggregate { + pub fn try_new( + input: Arc, + group_expr: Vec, + aggr_expr: Vec, + schema: DFSchemaRef, + ) -> datafusion_common::Result { + if group_expr.is_empty() && aggr_expr.is_empty() { + return Err(DataFusionError::Plan( + "Aggregate requires at least one grouping or aggregate expression" + .to_string(), + )); + } + let group_expr_count = grouping_set_expr_count(&group_expr)?; + if schema.fields().len() != group_expr_count + aggr_expr.len() { + return Err(DataFusionError::Plan(format!( + "Aggregate schema has wrong number of fields. Expected {} got {}", + group_expr_count + aggr_expr.len(), + schema.fields().len() + ))); + } + Ok(Self { + input, + group_expr, + aggr_expr, + schema, + }) + } + pub fn try_from_plan(plan: &LogicalPlan) -> datafusion_common::Result<&Aggregate> { match plan { LogicalPlan::Aggregate(it) => Ok(it), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 7d3f78b8fa31..e748536d735d 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -45,6 +45,22 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result Ok(()) } +/// Count the number of distinct exprs in a list of group by expressions. If the +/// first element is a `GroupingSet` expression then it must be the only expr. +pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { + if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { + if group_expr.len() > 1 { + return Err(DataFusionError::Plan( + "Invalid group by expressions, GroupingSet must be the only expression" + .to_string(), + )); + } + Ok(grouping_set.distinct_expr().len()) + } else { + Ok(group_expr.len()) + } +} + /// Find all distinct exprs in a list of group by expressions. If the /// first element is a `GroupingSet` expression then it must be the only expr. pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { @@ -395,12 +411,12 @@ pub fn from_plan( })), LogicalPlan::Aggregate(Aggregate { group_expr, schema, .. - }) => Ok(LogicalPlan::Aggregate(Aggregate { - group_expr: expr[0..group_expr.len()].to_vec(), - aggr_expr: expr[group_expr.len()..].to_vec(), - input: Arc::new(inputs[0].clone()), - schema: schema.clone(), - })), + }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(inputs[0].clone()), + expr[0..group_expr.len()].to_vec(), + expr[group_expr.len()..].to_vec(), + schema.clone(), + )?)), LogicalPlan::Sort(Sort { .. }) => Ok(LogicalPlan::Sort(Sort { expr: expr.to_vec(), input: Arc::new(inputs[0].clone()), diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 305283d9943d..f015aeaa0bca 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -189,12 +189,12 @@ fn optimize( let new_aggr_expr = new_expr.pop().unwrap(); let new_group_expr = new_expr.pop().unwrap(); - Ok(LogicalPlan::Aggregate(Aggregate { - input: Arc::new(new_input), - group_expr: new_group_expr, - aggr_expr: new_aggr_expr, - schema: schema.clone(), - })) + Ok(LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(new_input), + new_group_expr, + new_aggr_expr, + schema.clone(), + )?)) } LogicalPlan::Sort(Sort { expr, input }) => { let arrays = to_arrays(expr, input, &mut expr_set)?; diff --git a/datafusion/optimizer/src/projection_push_down.rs b/datafusion/optimizer/src/projection_push_down.rs index aa3cdfb4252c..80cc1044df25 100644 --- a/datafusion/optimizer/src/projection_push_down.rs +++ b/datafusion/optimizer/src/projection_push_down.rs @@ -345,18 +345,18 @@ fn optimize_plan( schema.metadata().clone(), )?; - Ok(LogicalPlan::Aggregate(Aggregate { - group_expr: group_expr.clone(), - aggr_expr: new_aggr_expr, - input: Arc::new(optimize_plan( + Ok(LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(optimize_plan( _optimizer, input, &new_required_columns, true, _optimizer_config, )?), - schema: DFSchemaRef::new(new_schema), - })) + group_expr.clone(), + new_aggr_expr, + DFSchemaRef::new(new_schema), + )?)) } // scans: // * remove un-used columns from the scan projection diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 3244fac8d863..e36706caa226 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -100,12 +100,12 @@ fn optimize(plan: &LogicalPlan) -> Result { all_field, input.schema().metadata().clone(), )?; - let grouped_agg = LogicalPlan::Aggregate(Aggregate { - input: input.clone(), - group_expr: all_group_args, - aggr_expr: Vec::new(), - schema: Arc::new(grouped_schema.clone()), - }); + let grouped_agg = LogicalPlan::Aggregate(Aggregate::try_new( + input.clone(), + all_group_args, + Vec::new(), + Arc::new(grouped_schema.clone()), + )?); let grouped_agg = optimize_children(&grouped_agg); let final_agg_schema = Arc::new(DFSchema::new_with_metadata( base_group_expr @@ -129,13 +129,12 @@ fn optimize(plan: &LogicalPlan) -> Result { )); }); - let final_agg = LogicalPlan::Aggregate(Aggregate { - input: Arc::new(grouped_agg?), - group_expr: group_expr.clone(), - aggr_expr: new_aggr_expr, - schema: final_agg_schema, - }); - + let final_agg = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(grouped_agg?), + group_expr.clone(), + new_aggr_expr, + final_agg_schema, + )?); Ok(LogicalPlan::Projection(Projection::try_new_with_schema( alias_expr, Arc::new(final_agg),