diff --git a/src/ast/mod.rs b/src/ast/mod.rs index eaf99b31b..2f723f012 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -27,7 +27,7 @@ pub use self::ddl::{ pub use self::operator::{BinaryOperator, UnaryOperator}; pub use self::query::{ Cte, Fetch, Join, JoinConstraint, JoinOperator, OrderByExpr, Query, Select, SelectItem, - SetExpr, SetOperator, TableAlias, TableFactor, TableWithJoins, Values, + SetExpr, SetOperator, TableAlias, TableFactor, TableWithJoins, Top, Values, }; pub use self::value::{DateTimeField, Value}; diff --git a/src/ast/query.rs b/src/ast/query.rs index 656f7f14b..a5eea141f 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -114,6 +114,8 @@ impl fmt::Display for SetOperator { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Select { pub distinct: bool, + /// MSSQL syntax: `TOP () [ PERCENT ] [ WITH TIES ]` + pub top: Option, /// projection expressions pub projection: Vec, /// FROM @@ -128,12 +130,11 @@ pub struct Select { impl fmt::Display for Select { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "SELECT{} {}", - if self.distinct { " DISTINCT" } else { "" }, - display_comma_separated(&self.projection) - )?; + write!(f, "SELECT{}", if self.distinct { " DISTINCT" } else { "" })?; + if let Some(ref top) = self.top { + write!(f, " {}", top)?; + } + write!(f, " {}", display_comma_separated(&self.projection))?; if !self.from.is_empty() { write!(f, " FROM {}", display_comma_separated(&self.from))?; } @@ -408,6 +409,26 @@ impl fmt::Display for Fetch { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Top { + /// SQL semantic equivalent of LIMIT but with same structure as FETCH. + pub with_ties: bool, + pub percent: bool, + pub quantity: Option, +} + +impl fmt::Display for Top { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let extension = if self.with_ties { " WITH TIES" } else { "" }; + if let Some(ref quantity) = self.quantity { + let percent = if self.percent { " PERCENT" } else { "" }; + write!(f, "TOP ({}){}{}", quantity, percent, extension) + } else { + write!(f, "TOP{}", extension) + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Values(pub Vec>); diff --git a/src/dialect/keywords.rs b/src/dialect/keywords.rs index c083c0692..9795f2af3 100644 --- a/src/dialect/keywords.rs +++ b/src/dialect/keywords.rs @@ -374,6 +374,7 @@ define_keywords!( TIMEZONE_HOUR, TIMEZONE_MINUTE, TO, + TOP, TRAILING, TRANSACTION, TRANSLATE, @@ -426,7 +427,7 @@ define_keywords!( /// can be parsed unambiguously without looking ahead. pub const RESERVED_FOR_TABLE_ALIAS: &[&str] = &[ // Reserved as both a table and a column alias: - WITH, SELECT, WHERE, GROUP, HAVING, ORDER, LIMIT, OFFSET, FETCH, UNION, EXCEPT, INTERSECT, + WITH, SELECT, WHERE, GROUP, HAVING, ORDER, TOP, LIMIT, OFFSET, FETCH, UNION, EXCEPT, INTERSECT, // Reserved only as a table alias in the `FROM`/`JOIN` clauses: ON, JOIN, INNER, CROSS, FULL, LEFT, RIGHT, NATURAL, USING, // for MSSQL-specific OUTER APPLY (seems reserved in most dialects) diff --git a/src/parser.rs b/src/parser.rs index cbdcaba09..c9e32ed3b 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -783,7 +783,6 @@ impl Parser { } /// Bail out if the current token is not one of the expected keywords, or consume it if it is - #[must_use] pub fn expect_one_of_keywords( &mut self, keywords: &[&'static str], @@ -1561,6 +1560,13 @@ impl Parser { if all && distinct { return parser_err!("Cannot specify both ALL and DISTINCT in SELECT"); } + + let top = if self.parse_keyword("TOP") { + Some(self.parse_top()?) + } else { + None + }; + let projection = self.parse_comma_separated(Parser::parse_select_item)?; // Note that for keywords to be properly handled here, they need to be @@ -1594,6 +1600,7 @@ impl Parser { Ok(Select { distinct, + top, projection, from, selection, @@ -1940,6 +1947,28 @@ impl Parser { Ok(OrderByExpr { expr, asc }) } + /// Parse a TOP clause, MSSQL equivalent of LIMIT, + /// that follows after SELECT [DISTINCT]. + pub fn parse_top(&mut self) -> Result { + let quantity = if self.consume_token(&Token::LParen) { + let quantity = self.parse_expr()?; + self.expect_token(&Token::RParen)?; + Some(quantity) + } else { + Some(Expr::Value(self.parse_number_value()?)) + }; + + let percent = self.parse_keyword("PERCENT"); + + let with_ties = self.parse_keywords(vec!["WITH", "TIES"]); + + Ok(Top { + with_ties, + percent, + quantity, + }) + } + /// Parse a LIMIT clause pub fn parse_limit(&mut self) -> Result, ParserError> { if self.parse_keyword("ALL") { diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 62d534895..96c9535ea 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -522,6 +522,7 @@ fn peeking_take_while( #[cfg(test)] mod tests { use super::super::dialect::GenericDialect; + use super::super::dialect::MsSqlDialect; use super::*; #[test] @@ -782,6 +783,28 @@ mod tests { compare(expected, tokens); } + #[test] + fn tokenize_mssql_top() { + let sql = "SELECT TOP 5 [bar] FROM foo"; + let dialect = MsSqlDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, sql); + let tokens = tokenizer.tokenize().unwrap(); + let expected = vec![ + Token::make_keyword("SELECT"), + Token::Whitespace(Whitespace::Space), + Token::make_keyword("TOP"), + Token::Whitespace(Whitespace::Space), + Token::Number(String::from("5")), + Token::Whitespace(Whitespace::Space), + Token::make_word("bar", Some('[')), + Token::Whitespace(Whitespace::Space), + Token::make_keyword("FROM"), + Token::Whitespace(Whitespace::Space), + Token::make_word("foo", None), + ]; + compare(expected, tokens); + } + fn compare(expected: Vec, actual: Vec) { //println!("------------------------------"); //println!("tokens = {:?}", actual); diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index b5170e208..2774d43ef 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -68,6 +68,48 @@ fn parse_mssql_apply_join() { ); } +#[test] +fn parse_mssql_top_paren() { + let sql = "SELECT TOP (5) * FROM foo"; + let select = ms_and_generic().verified_only_select(sql); + let top = select.top.unwrap(); + assert_eq!(Some(Expr::Value(number("5"))), top.quantity); + assert!(!top.percent); +} + +#[test] +fn parse_mssql_top_percent() { + let sql = "SELECT TOP (5) PERCENT * FROM foo"; + let select = ms_and_generic().verified_only_select(sql); + let top = select.top.unwrap(); + assert_eq!(Some(Expr::Value(number("5"))), top.quantity); + assert!(top.percent); +} + +#[test] +fn parse_mssql_top_with_ties() { + let sql = "SELECT TOP (5) WITH TIES * FROM foo"; + let select = ms_and_generic().verified_only_select(sql); + let top = select.top.unwrap(); + assert_eq!(Some(Expr::Value(number("5"))), top.quantity); + assert!(top.with_ties); +} + +#[test] +fn parse_mssql_top_percent_with_ties() { + let sql = "SELECT TOP (10) PERCENT WITH TIES * FROM foo"; + let select = ms_and_generic().verified_only_select(sql); + let top = select.top.unwrap(); + assert_eq!(Some(Expr::Value(number("10"))), top.quantity); + assert!(top.percent); +} + +#[test] +fn parse_mssql_top() { + let sql = "SELECT TOP 5 bar, baz FROM foo"; + let _ = ms_and_generic().one_statement_parses_to(sql, "SELECT TOP (5) bar, baz FROM foo"); +} + fn ms() -> TestedDialects { TestedDialects { dialects: vec![Box::new(MsSqlDialect {})], diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index ce9d0053b..cc6433322 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -77,7 +77,7 @@ fn parse_show_columns() { Statement::ShowColumns { extended: false, full: false, - table_name: table_name.clone(), + table_name: table_name, filter: Some(ShowStatementFilter::Where( mysql_and_generic().verified_expr("1 = 2") )),