diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index dc43cbaf10d0..708590e97ac3 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -2274,15 +2274,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { normalize_ident(function.name.0[0].clone()) }; - // first, check SQL reserved words - if name == "rollup" { - let args = self.function_args_to_expr(function.args, schema)?; - return Ok(Expr::GroupingSet(GroupingSet::Rollup(args))); - } else if name == "cube" { - let args = self.function_args_to_expr(function.args, schema)?; - return Ok(Expr::GroupingSet(GroupingSet::Cube(args))); - } - // next, scalar built-in if let Ok(fun) = BuiltinScalarFunction::from_str(&name) { let args = self.function_args_to_expr(function.args, schema)?; @@ -2387,6 +2378,35 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } + SQLExpr::Rollup(exprs) => { + let args: Result> = exprs.into_iter().map(|v| { + if v.len() != 1 { + Err(DataFusionError::Internal("Tuple expressions are not supported for Rollup expressions".to_string())) + } else { + self.sql_expr_to_logical_expr(v[0].clone(), schema, planner_context) + } + }).collect(); + Ok(Expr::GroupingSet(GroupingSet::Rollup(args?))) + } + + SQLExpr::Cube(exprs) => { + let args: Result> = exprs.into_iter().map(|v| { + if v.len() != 1 { + Err(DataFusionError::Internal("Tuple expressions not are supported for Cube expressions".to_string())) + } else { + self.sql_expr_to_logical_expr(v[0].clone(), schema, planner_context) + } + }).collect(); + Ok(Expr::GroupingSet(GroupingSet::Cube(args?))) + } + + SQLExpr::GroupingSets(exprs) => { + let args: Result>> = exprs.into_iter().map(|v| { + v.into_iter().map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)).collect() + }).collect(); + Ok(Expr::GroupingSet(GroupingSet::GroupingSets(args?))) + } + SQLExpr::Floor { expr, field: _field } => { let fun = BuiltinScalarFunction::Floor; let args = vec![self.sql_expr_to_logical_expr(*expr, schema, planner_context)?]; @@ -5829,11 +5849,12 @@ mod tests { quick_test(sql, expected); } - #[ignore] // see https://github.com/apache/arrow-datafusion/issues/2469 #[test] fn aggregate_with_grouping_sets() { let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; - let expected = "TBD"; + let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[person.id, GROUPING SETS ((person.state), (person.state, person.age), (person.id, person.state))]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person"; quick_test(sql, expected); }