diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 6a0df61ac182..e8cbde058566 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -22,11 +22,19 @@ use sqlparser::keywords::ALL_KEYWORDS; /// /// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`) /// but this behavior can be overridden as needed -/// Note: this trait will eventually be replaced by the Dialect in the SQLparser package +/// +/// **Note**: This trait will eventually be replaced by the Dialect in the SQLparser package /// /// See +/// See also the discussion in pub trait Dialect { + /// Return the character used to quote identifiers. fn identifier_quote_style(&self, _identifier: &str) -> Option; + + /// Does the dialect support specifying `NULLS FIRST/LAST` in `ORDER BY` clauses? + fn supports_nulls_first_in_sort(&self) -> bool { + true + } } pub struct DefaultDialect {} @@ -57,6 +65,10 @@ impl Dialect for MySqlDialect { fn identifier_quote_style(&self, _: &str) -> Option { Some('`') } + + fn supports_nulls_first_in_sort(&self) -> bool { + false + } } pub struct SqliteDialect {} diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 5fe744e359a6..ea991102df36 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -474,10 +474,17 @@ impl Unparser<'_> { nulls_first, }) => { let sql_parser_expr = self.expr_to_sql(expr)?; + + let nulls_first = if self.dialect.supports_nulls_first_in_sort() { + Some(*nulls_first) + } else { + None + }; + Ok(Unparsed::OrderByExpr(ast::OrderByExpr { expr: sql_parser_expr, asc: Some(*asc), - nulls_first: Some(*nulls_first), + nulls_first, })) } _ => { diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 07a8d7817131..833ac5cdbe3a 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -439,10 +439,17 @@ impl Unparser<'_> { .map(|expr: &Expr| match expr { Expr::Sort(sort_expr) => { let col = self.expr_to_sql(&sort_expr.expr)?; + + let nulls_first = if self.dialect.supports_nulls_first_in_sort() { + Some(sort_expr.nulls_first) + } else { + None + }; + Ok(ast::OrderByExpr { asc: Some(sort_expr.asc), expr: col, - nulls_first: Some(sort_expr.nulls_first), + nulls_first, }) } _ => plan_err!("Expecting Sort expr"), diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index e7da113e60d6..685911ee81d7 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -33,7 +33,11 @@ use datafusion_expr::{ Volatility, WindowUDF, }; use datafusion_functions::{string, unicode}; -use datafusion_sql::unparser::{expr_to_sql, plan_to_sql}; +use datafusion_sql::unparser::dialect::{ + DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect, + MySqlDialect as UnparserMySqlDialect, +}; +use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser}; use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}, @@ -4726,6 +4730,52 @@ fn roundtrip_crossjoin() -> Result<()> { Ok(()) } +#[test] +fn roundtrip_statement_with_dialect() -> Result<()> { + struct TestStatementWithDialect { + sql: &'static str, + expected: &'static str, + parser_dialect: Box, + unparser_dialect: Box, + } + let tests: Vec = vec![ + TestStatementWithDialect { + sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", + expected: + "SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id` ASC LIMIT 10", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", + expected: r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST LIMIT 10"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + ]; + + for query in tests { + let statement = Parser::new(&*query.parser_dialect) + .try_with_sql(query.sql)? + .parse_statement()?; + + let context = MockContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); + + let unparser = Unparser::new(&*query.unparser_dialect); + let roundtrip_statement = unparser.plan_to_sql(&plan)?; + + let actual = format!("{}", &roundtrip_statement); + println!("roundtrip sql: {actual}"); + println!("plan {}", plan.display_indent()); + + assert_eq!(query.expected, actual); + } + + Ok(()) +} + #[test] fn test_unnest_logical_plan() -> Result<()> { let query = "select unnest(struct_col), unnest(array_col), struct_col, array_col from unnest_table";