From cc452a1f8e6ce989c14d3ddbf9328e7b7c1e04f2 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Sun, 25 Dec 2022 21:35:51 +1100 Subject: [PATCH 1/3] Sql planner support for rollup/cube/grouping sets ast nodes --- datafusion/sql/src/planner.rs | 78 ++++++++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index dc43cbaf10d0..f076a83cd1ae 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -2275,6 +2275,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; // first, check SQL reserved words + // TODO: remove when https://github.com/sqlparser-rs/sqlparser-rs/pull/771 if name == "rollup" { let args = self.function_args_to_expr(function.args, schema)?; return Ok(Expr::GroupingSet(GroupingSet::Rollup(args))); @@ -2387,6 +2388,35 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } + SQLExpr::Rollup(exprs) => { + let args: Result<Vec<_>> = exprs.into_iter().map(|v| { + if v.len() != 1 { + Err(DataFusionError::Internal("Tuple expressions are not supported for Rollup expressions".to_string())) + } else { + self.sql_expr_to_logical_expr(v[0].clone(), schema, planner_context) + } + }).collect(); + Ok(Expr::GroupingSet(GroupingSet::Rollup(args?))) + } + + SQLExpr::Cube(exprs) => { + let args: Result<Vec<_>> = exprs.into_iter().map(|v| { + if v.len() != 1 { + Err(DataFusionError::Internal("Tuple expressions not are supported for Cube expressions".to_string())) + } else { + self.sql_expr_to_logical_expr(v[0].clone(), schema, planner_context) + } + }).collect(); + Ok(Expr::GroupingSet(GroupingSet::Cube(args?))) + } + + SQLExpr::GroupingSets(exprs) => { + let args: Result<Vec<Vec<_>>> = exprs.into_iter().map(|v| { + v.into_iter().map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)).collect() + }).collect(); + Ok(Expr::GroupingSet(GroupingSet::GroupingSets(args?))) + } + SQLExpr::Floor { expr, field: _field } => { let 
fun = BuiltinScalarFunction::Floor; let args = vec![self.sql_expr_to_logical_expr(*expr, schema, planner_context)?]; @@ -3264,7 +3294,9 @@ fn ensure_any_column_reference_is_unambiguous( mod tests { use std::any::Any; - use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; + use sqlparser::dialect::{ + Dialect, GenericDialect, HiveDialect, MySqlDialect, PostgreSqlDialect, + }; use datafusion_common::assert_contains; @@ -5833,10 +5865,52 @@ mod tests { #[test] fn aggregate_with_grouping_sets() { let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; - let expected = "TBD"; + let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[person.id, GROUPING SETS ((person.state), (person.state, person.age), (person.id, person.state))]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person"; quick_test(sql, expected); } + #[test] + // TODO: remove when https://github.com/sqlparser-rs/sqlparser-rs/pull/771 + fn postgres_aggregate_with_grouping_sets() -> Result<()> { + let dialect = &PostgreSqlDialect {}; + let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; + let plan = logical_plan_with_dialect(sql, dialect)?; + let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[person.id, GROUPING SETS ((person.state), (person.state, person.age), (person.id, person.state))]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person".to_string(); + assert_eq!(plan.display_indent().to_string(), expected); + Ok(()) + } + + #[test] + // TODO: remove when https://github.com/sqlparser-rs/sqlparser-rs/pull/771 + fn postgres_aggregate_with_cube() -> Result<()> { + let dialect = &PostgreSqlDialect {}; + let sql = + "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, CUBE (state, age)"; + let plan = logical_plan_with_dialect(sql, dialect)?; + 
let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[person.id, CUBE (person.state, person.age)]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person".to_string(); + assert_eq!(plan.display_indent().to_string(), expected); + Ok(()) + } + + #[test] + // TODO: remove when https://github.com/sqlparser-rs/sqlparser-rs/pull/771 + fn postgres_aggregate_with_rollup() -> Result<()> { + let dialect = &PostgreSqlDialect {}; + let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)"; + let plan = logical_plan_with_dialect(sql, dialect)?; + let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[person.id, ROLLUP (person.state, person.age)]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person"; + assert_eq!(plan.display_indent().to_string(), expected); + Ok(()) + } + #[test] fn join_on_disjunction_condition() { let sql = "SELECT id, order_id \ From e4d5717eff00293640ef40e96c55fbd973878b22 Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Mon, 26 Dec 2022 18:50:02 +1100 Subject: [PATCH 2/3] Trigger build From d8cf475008ff9f19bc60f36077bd16cbaedc1e7d Mon Sep 17 00:00:00 2001 From: Jefffrey <22608443+Jefffrey@users.noreply.github.com> Date: Fri, 30 Dec 2022 07:56:14 +1100 Subject: [PATCH 3/3] sqlparser update --- datafusion/sql/src/planner.rs | 55 +---------------------------------- 1 file changed, 1 insertion(+), 54 deletions(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index f076a83cd1ae..708590e97ac3 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -2274,16 +2274,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { normalize_ident(function.name.0[0].clone()) }; - // first, check SQL reserved words - // TODO: remove when https://github.com/sqlparser-rs/sqlparser-rs/pull/771 - if name == "rollup" { - let args = 
self.function_args_to_expr(function.args, schema)?; - return Ok(Expr::GroupingSet(GroupingSet::Rollup(args))); - } else if name == "cube" { - let args = self.function_args_to_expr(function.args, schema)?; - return Ok(Expr::GroupingSet(GroupingSet::Cube(args))); - } - // next, scalar built-in if let Ok(fun) = BuiltinScalarFunction::from_str(&name) { let args = self.function_args_to_expr(function.args, schema)?; @@ -3294,9 +3284,7 @@ fn ensure_any_column_reference_is_unambiguous( mod tests { use std::any::Any; - use sqlparser::dialect::{ - Dialect, GenericDialect, HiveDialect, MySqlDialect, PostgreSqlDialect, - }; + use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use datafusion_common::assert_contains; @@ -5861,7 +5849,6 @@ mod tests { quick_test(sql, expected); } - #[ignore] // see https://github.com/apache/arrow-datafusion/issues/2469 #[test] fn aggregate_with_grouping_sets() { let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; @@ -5871,46 +5858,6 @@ mod tests { quick_test(sql, expected); } - #[test] - // TODO: remove when https://github.com/sqlparser-rs/sqlparser-rs/pull/771 - fn postgres_aggregate_with_grouping_sets() -> Result<()> { - let dialect = &PostgreSqlDialect {}; - let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; - let plan = logical_plan_with_dialect(sql, dialect)?; - let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[person.id, GROUPING SETS ((person.state), (person.state, person.age), (person.id, person.state))]], aggr=[[COUNT(UInt8(1))]]\ - \n TableScan: person".to_string(); - assert_eq!(plan.display_indent().to_string(), expected); - Ok(()) - } - - #[test] - // TODO: remove when https://github.com/sqlparser-rs/sqlparser-rs/pull/771 - fn postgres_aggregate_with_cube() -> Result<()> { - let dialect = &PostgreSqlDialect {}; - 
let sql = - "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, CUBE (state, age)"; - let plan = logical_plan_with_dialect(sql, dialect)?; - let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[person.id, CUBE (person.state, person.age)]], aggr=[[COUNT(UInt8(1))]]\ - \n TableScan: person".to_string(); - assert_eq!(plan.display_indent().to_string(), expected); - Ok(()) - } - - #[test] - // TODO: remove when https://github.com/sqlparser-rs/sqlparser-rs/pull/771 - fn postgres_aggregate_with_rollup() -> Result<()> { - let dialect = &PostgreSqlDialect {}; - let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)"; - let plan = logical_plan_with_dialect(sql, dialect)?; - let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[person.id, ROLLUP (person.state, person.age)]], aggr=[[COUNT(UInt8(1))]]\ - \n TableScan: person"; - assert_eq!(plan.display_indent().to_string(), expected); - Ok(()) - } - #[test] fn join_on_disjunction_condition() { let sql = "SELECT id, order_id \