From 872f5bd0abf33db30e8313be9cb6a31a868df53e Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Thu, 11 Apr 2024 22:48:50 +0200 Subject: [PATCH 1/9] trailing_comma option per dialect and related fixes --- src/dialect/bigquery.rs | 4 ++++ src/dialect/duckdb.rs | 4 ++++ src/dialect/mod.rs | 8 ++++++++ src/dialect/snowflake.rs | 4 ++++ src/parser/mod.rs | 14 +++++++++++++- tests/sqlparser_common.rs | 30 ++++++++++++++++++++++++++++-- 6 files changed, 61 insertions(+), 3 deletions(-) diff --git a/src/dialect/bigquery.rs b/src/dialect/bigquery.rs index bcd27c3b5..83ac34829 100644 --- a/src/dialect/bigquery.rs +++ b/src/dialect/bigquery.rs @@ -22,6 +22,10 @@ impl Dialect for BigQueryDialect { ch == '`' } + fn supports_projection_trailing_commas(&self) -> bool { + true + } + fn is_identifier_start(&self, ch: char) -> bool { ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_' } diff --git a/src/dialect/duckdb.rs b/src/dialect/duckdb.rs index f08545b99..a52cd8fbe 100644 --- a/src/dialect/duckdb.rs +++ b/src/dialect/duckdb.rs @@ -18,6 +18,10 @@ pub struct DuckDbDialect; // In most cases the redshift dialect is identical to [`PostgresSqlDialect`]. impl Dialect for DuckDbDialect { + fn supports_trailing_commas(&self) -> bool { + true + } + fn is_identifier_start(&self, ch: char) -> bool { ch.is_alphabetic() || ch == '_' } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 2463121e7..b04f5bcd6 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -157,6 +157,14 @@ pub trait Dialect: Debug + Any { // return None to fall back to the default behavior None } + /// Does the dialect support trailing commas around the query? + fn supports_trailing_commas(&self) -> bool { + false + } + /// Does the dialect support trailing commas only in proejction list? + fn supports_projection_trailing_commas(&self) -> bool { + self.supports_trailing_commas() + } /// Dialect-specific infix parser override fn parse_infix( &self, diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs index 1d9d983e5..fa2a715a8 100644 --- a/src/dialect/snowflake.rs +++ b/src/dialect/snowflake.rs @@ -38,6 +38,10 @@ impl Dialect for SnowflakeDialect { ch.is_ascii_lowercase() || ch.is_ascii_uppercase() || ch == '_' } + fn supports_projection_trailing_commas(&self) -> bool { + true + } + fn is_identifier_part(&self, ch: char) -> bool { ch.is_ascii_lowercase() || ch.is_ascii_uppercase() diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 5bae7a133..76b04a87a 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -292,7 +292,7 @@ impl<'a> Parser<'a> { index: 0, dialect, recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH), - options: ParserOptions::default(), + options: ParserOptions::new().with_trailing_commas(dialect.supports_trailing_commas()), } } @@ -343,6 +343,7 @@ impl<'a> Parser<'a> { /// assert!(matches!(result, Ok(_))); /// # Ok(()) /// # } + /// /// ``` pub fn with_options(mut self, options: ParserOptions) -> Self { self.options = options; @@ -8530,6 +8531,9 @@ impl<'a> Parser<'a> { with_privileges_keyword: self.parse_keyword(Keyword::PRIVILEGES), } } else { + let old_value = self.options.trailing_commas; + self.options.trailing_commas = false; + let (actions, err): (Vec<_>, Vec<_>) = self .parse_comma_separated(Parser::parse_grant_permission)? .into_iter() @@ -8553,6 +8557,8 @@ impl<'a> Parser<'a> { }) .partition(Result::is_ok); + self.options.trailing_commas = old_value; + if !err.is_empty() { let errors: Vec = err.into_iter().filter_map(|x| x.err()).collect(); return Err(ParserError::ParserError(format!( @@ -9007,6 +9013,12 @@ impl<'a> Parser<'a> { Expr::Wildcard => Ok(SelectItem::Wildcard( self.parse_wildcard_additional_options()?, )), + Expr::Identifier(v) if v.value.to_lowercase() == "from" => { + parser_err!( + format!("Trailing comma not allowed in dialect: {:?}", self.dialect), + self.peek_token().location + ) + } expr => { let expr: Expr = if self.dialect.supports_filter_during_aggregation() && self.parse_keyword(Keyword::FILTER) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index c94bd3779..234dafe9c 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -8394,9 +8394,11 @@ fn parse_non_latin_identifiers() { #[test] fn parse_trailing_comma() { + // At the moment, Duck DB is the only dialect that allows + // trailing commas anywhere in the query let trailing_commas = TestedDialects { - dialects: vec![Box::new(GenericDialect {})], - options: Some(ParserOptions::new().with_trailing_commas(true)), + dialects: vec![Box::new(DuckDbDialect {})], + options: None, }; trailing_commas.one_statement_parses_to( @@ -8421,6 +8423,30 @@ fn parse_trailing_comma() { trailing_commas.verified_stmt("SELECT DISTINCT ON (album_id) name FROM track"); } +#[test] +fn parse_porjection_trailing_comma() { + // Some dialects allow trailing commas only in the projection + let trailing_commas = TestedDialects { + dialects: vec![Box::new(SnowflakeDialect {}), Box::new(BigQueryDialect {})], + options: None, + }; + + trailing_commas.one_statement_parses_to( + "SELECT album_id, name, FROM track", + "SELECT album_id, name FROM track", + ); + + trailing_commas.verified_stmt("SELECT album_id, name FROM track"); + + trailing_commas.verified_stmt("SELECT * FROM track ORDER BY milliseconds"); + + trailing_commas.verified_stmt("SELECT DISTINCT ON (album_id) name FROM track"); + + assert!(trailing_commas + .parse_sql_statements("SELECT * FROM track ORDER BY milliseconds,") + .is_err(),) +} + #[test] fn parse_create_type() { let create_type = From 4c0799ac31b7c9ee464d525aad065682c3ba5eb7 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Thu, 11 Apr 2024 22:59:53 +0200 Subject: [PATCH 2/9] testing new error --- tests/sqlparser_common.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 234dafe9c..fe1697a47 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -8421,6 +8421,16 @@ fn parse_trailing_comma() { trailing_commas.verified_stmt("SELECT * FROM track ORDER BY milliseconds"); trailing_commas.verified_stmt("SELECT DISTINCT ON (album_id) name FROM track"); + + // doesn't allow any trailing commas + let trailing_commas = TestedDialects { + dialects: vec![Box::new(GenericDialect {})], + options: None, + }; + + assert!(trailing_commas + .parse_sql_statements("SELECT name, age, FROM employees;") + .is_err()); } #[test] From 9c9c041cbe63e8c038432b551102f8d1ddd2d5e2 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Thu, 11 Apr 2024 23:14:29 +0200 Subject: [PATCH 3/9] use new trait method --- src/dialect/mod.rs | 2 +- src/parser/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index b04f5bcd6..76e224848 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -161,7 +161,7 @@ pub trait Dialect: Debug + Any { fn supports_trailing_commas(&self) -> bool { false } - /// Does the dialect support trailing commas only in proejction list? + /// Does the dialect support trailing commas in proejction list? fn supports_projection_trailing_commas(&self) -> bool { self.supports_trailing_commas() } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 76b04a87a..971c33c89 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -3067,7 +3067,7 @@ impl<'a> Parser<'a> { // This pattern could be captured better with RAII type semantics, but it's quite a bit of // code to add for just one case, so we'll just do it manually here. let old_value = self.options.trailing_commas; - self.options.trailing_commas |= dialect_of!(self is BigQueryDialect | SnowflakeDialect); + self.options.trailing_commas |= self.dialect.supports_projection_trailing_commas(); let ret = self.parse_comma_separated(|p| p.parse_select_item()); self.options.trailing_commas = old_value; From 064e1973f53a2e9204db0e93096ca8a1c5e385e7 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Fri, 12 Apr 2024 00:08:15 +0200 Subject: [PATCH 4/9] typo --- src/dialect/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 76e224848..7374aa5dc 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -161,7 +161,7 @@ pub trait Dialect: Debug + Any { fn supports_trailing_commas(&self) -> bool { false } - /// Does the dialect support trailing commas in proejction list? + /// Does the dialect support trailing commas in the projection list? fn supports_projection_trailing_commas(&self) -> bool { self.supports_trailing_commas() } From db5573e6e2bb3567dbde976c1efef47ccef92d2d Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Fri, 12 Apr 2024 13:54:58 +0200 Subject: [PATCH 5/9] typo --- tests/sqlparser_common.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index fe1697a47..1bbe8ee82 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -8434,7 +8434,7 @@ fn parse_trailing_comma() { } #[test] -fn parse_porjection_trailing_comma() { +fn parse_projection_trailing_comma() { // Some dialects allow trailing commas only in the projection let trailing_commas = TestedDialects { dialects: vec![Box::new(SnowflakeDialect {}), Box::new(BigQueryDialect {})], From 4949977a6f37ef8164a082e1be15341691d175a3 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Fri, 12 Apr 2024 15:44:24 +0200 Subject: [PATCH 6/9] remove trailing comma from create + better tests --- src/parser/mod.rs | 15 ++++++++---- tests/sqlparser_common.rs | 48 ++++++++++++++++++++++++++++++------- tests/sqlparser_postgres.rs | 2 +- 3 files changed, 52 insertions(+), 13 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 971c33c89..ef1abd591 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -5015,12 +5015,19 @@ impl<'a> Parser<'a> { } else { return self.expected("column name or constraint definition", self.peek_token()); } + let comma = self.consume_token(&Token::Comma); - if self.consume_token(&Token::RParen) { - // allow a trailing comma, even though it's not in standard - break; - } else if !comma { + let rparen = self.peek_token().token == Token::RParen; + + if comma && rparen && !self.options.trailing_commas { + return self.expected("column definition after ','", self.peek_token()); + } else if !comma && !rparen { return self.expected("',' or ')' after column definition", self.peek_token()); + }; + + if rparen && (!comma || self.options.trailing_commas) { + let _ = self.consume_token(&Token::RParen); + break; } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 1bbe8ee82..4e7cafea7 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -3354,8 +3354,13 @@ fn parse_create_table_clone() { #[test] fn parse_create_table_trailing_comma() { - let sql = "CREATE TABLE foo (bar int,)"; - all_dialects().one_statement_parses_to(sql, "CREATE TABLE foo (bar INT)"); + let dialect = TestedDialects { + dialects: vec![Box::new(DuckDbDialect {})], + options: None, + }; + + let sql = "CREATE TABLE foo (bar int,);"; + dialect.one_statement_parses_to(sql, "CREATE TABLE foo (bar INT)"); } #[test] @@ -8416,6 +8421,11 @@ fn parse_trailing_comma() { "SELECT DISTINCT ON (album_id) name FROM track", ); + trailing_commas.one_statement_parses_to( + "CREATE TABLE employees (name text, age int,)", + "CREATE TABLE employees (name TEXT, age INT)", + ); + trailing_commas.verified_stmt("SELECT album_id, name FROM track"); trailing_commas.verified_stmt("SELECT * FROM track ORDER BY milliseconds"); @@ -8428,9 +8438,21 @@ fn parse_trailing_comma() { options: None, }; - assert!(trailing_commas - .parse_sql_statements("SELECT name, age, FROM employees;") - .is_err()); + assert_eq!( + trailing_commas + .parse_sql_statements("SELECT name, age, FROM employees;") + .unwrap_err(), + ParserError::ParserError( + "Trailing comma not allowed in dialect: GenericDialect".to_string() + ) + ); + + assert_eq!( + trailing_commas + .parse_sql_statements("CREATE TABLE employees (name text, age int,)") + .unwrap_err(), + ParserError::ParserError("Expected column definition after ',', found: )".to_string()) + ); } #[test] @@ -8452,9 +8474,19 @@ fn parse_projection_trailing_comma() { trailing_commas.verified_stmt("SELECT DISTINCT ON (album_id) name FROM track"); - assert!(trailing_commas - .parse_sql_statements("SELECT * FROM track ORDER BY milliseconds,") - .is_err(),) + assert_eq!( + trailing_commas + .parse_sql_statements("SELECT * FROM track ORDER BY milliseconds,") + .unwrap_err(), + ParserError::ParserError("Expected an expression:, found: EOF".to_string()) + ); + + assert_eq!( + trailing_commas + .parse_sql_statements("CREATE TABLE employees (name text, age int,)") + .unwrap_err(), + ParserError::ParserError("Expected column definition after ',', found: )".to_string()) + ); } #[test] diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index ea5c9875b..9b743fa4d 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -3572,7 +3572,7 @@ fn parse_create_table_with_alias() { int2_col INT2, float8_col FLOAT8, float4_col FLOAT4, - bool_col BOOL, + bool_col BOOL );"; match pg_and_generic().one_statement_parses_to(sql, "") { Statement::CreateTable { From d609d2d8e8dca46373cdc65808c1ccfba45b1658 Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Fri, 12 Apr 2024 17:01:36 +0200 Subject: [PATCH 7/9] change error message to not include dialect name --- examples/parse_select.rs | 2 +- src/parser/mod.rs | 3 +-- tests/sqlparser_common.rs | 6 ++---- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/examples/parse_select.rs b/examples/parse_select.rs index 71fe1fa1e..8545eca23 100644 --- a/examples/parse_select.rs +++ b/examples/parse_select.rs @@ -16,7 +16,7 @@ use sqlparser::dialect::GenericDialect; use sqlparser::parser::*; fn main() { - let sql = "SELECT a, b, 123, myfunc(b) \ + let sql = "SELECT a, b, 123, myfunc(b), \ FROM table_1 \ WHERE a > b AND b < 100 \ ORDER BY a DESC, b"; diff --git a/src/parser/mod.rs b/src/parser/mod.rs index ef1abd591..aea73f7ec 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -343,7 +343,6 @@ impl<'a> Parser<'a> { /// assert!(matches!(result, Ok(_))); /// # Ok(()) /// # } - /// /// ``` pub fn with_options(mut self, options: ParserOptions) -> Self { self.options = options; @@ -9022,7 +9021,7 @@ impl<'a> Parser<'a> { )), Expr::Identifier(v) if v.value.to_lowercase() == "from" => { parser_err!( - format!("Trailing comma not allowed in dialect: {:?}", self.dialect), + format!("Expected an expression, found: {}", v), self.peek_token().location ) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 4e7cafea7..23b537edd 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -8440,11 +8440,9 @@ fn parse_trailing_comma() { assert_eq!( trailing_commas - .parse_sql_statements("SELECT name, age, FROM employees;") + .parse_sql_statements("SELECT name, age, from employees;") .unwrap_err(), - ParserError::ParserError( - "Trailing comma not allowed in dialect: GenericDialect".to_string() - ) + ParserError::ParserError("Expected an expression, found: from".to_string()) ); assert_eq!( From 91fbcfb8aec2c3d042693e63a8795fe3ba5e45db Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Fri, 12 Apr 2024 18:39:19 +0200 Subject: [PATCH 8/9] better errors for create statements --- src/parser/mod.rs | 4 +--- tests/sqlparser_common.rs | 8 ++++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index aea73f7ec..f35b1f326 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -5018,9 +5018,7 @@ impl<'a> Parser<'a> { let comma = self.consume_token(&Token::Comma); let rparen = self.peek_token().token == Token::RParen; - if comma && rparen && !self.options.trailing_commas { - return self.expected("column definition after ','", self.peek_token()); - } else if !comma && !rparen { + if !comma && !rparen { return self.expected("',' or ')' after column definition", self.peek_token()); }; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 23b537edd..7ab4b16e1 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -8449,7 +8449,9 @@ fn parse_trailing_comma() { trailing_commas .parse_sql_statements("CREATE TABLE employees (name text, age int,)") .unwrap_err(), - ParserError::ParserError("Expected column definition after ',', found: )".to_string()) + ParserError::ParserError( + "Expected column name or constraint definition, found: )".to_string() + ) ); } @@ -8483,7 +8485,9 @@ fn parse_projection_trailing_comma() { trailing_commas .parse_sql_statements("CREATE TABLE employees (name text, age int,)") .unwrap_err(), - ParserError::ParserError("Expected column definition after ',', found: )".to_string()) + ParserError::ParserError( + "Expected column name or constraint definition, found: )".to_string() + ), ); } From ba4da0faff2f454b2c452b1902ae5c58f3cf7c3e Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen Date: Fri, 12 Apr 2024 18:40:34 +0200 Subject: [PATCH 9/9] revert test change --- examples/parse_select.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/parse_select.rs b/examples/parse_select.rs index 8545eca23..71fe1fa1e 100644 --- a/examples/parse_select.rs +++ b/examples/parse_select.rs @@ -16,7 +16,7 @@ use sqlparser::dialect::GenericDialect; use sqlparser::parser::*; fn main() { - let sql = "SELECT a, b, 123, myfunc(b), \ + let sql = "SELECT a, b, 123, myfunc(b) \ FROM table_1 \ WHERE a > b AND b < 100 \ ORDER BY a DESC, b";