Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub use self::ddl::{
pub use self::operator::{BinaryOperator, UnaryOperator};
pub use self::query::{
Cte, Fetch, Join, JoinConstraint, JoinOperator, Offset, OffsetRows, OrderByExpr, Query, Select,
SelectItem, SetExpr, SetOperator, TableAlias, TableFactor, TableWithJoins, Top, Values,
SelectItem, SetExpr, SetOperator, TableAlias, TableFactor, TableWithJoins, Top, Values, With,
};
pub use self::value::{DateTimeField, Value};

Expand Down
26 changes: 22 additions & 4 deletions src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use serde::{Deserialize, Serialize};
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Query {
/// WITH (common table expressions, or CTEs)
pub ctes: Vec<Cte>,
/// SELECT or UNION / EXCEPT / INTECEPT
pub with: Option<With>,
/// SELECT or UNION / EXCEPT / INTERSECT
pub body: SetExpr,
/// ORDER BY
pub order_by: Vec<OrderByExpr>,
Expand All @@ -35,8 +35,8 @@ pub struct Query {

impl fmt::Display for Query {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if !self.ctes.is_empty() {
write!(f, "WITH {} ", display_comma_separated(&self.ctes))?;
if let Some(ref with) = self.with {
write!(f, "{} ", with)?;
}
write!(f, "{}", self.body)?;
if !self.order_by.is_empty() {
Expand Down Expand Up @@ -157,6 +157,24 @@ impl fmt::Display for Select {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct With {
pub recursive: bool,
pub cte_tables: Vec<Cte>,
}

impl fmt::Display for With {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"WITH {}{}",
if self.recursive { "RECURSIVE " } else { "" },
display_comma_separated(&self.cte_tables)
)
}
}

/// A single CTE (used after `WITH`): `alias [(col1, col2, ...)] AS ( query )`
/// The names in the column list before `AS`, when specified, replace the names
/// of the columns returned by the query. The parser does not validate that the
Expand Down
12 changes: 7 additions & 5 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1795,11 +1795,13 @@ impl<'a> Parser<'a> {
/// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't
/// expect the initial keyword to be already consumed
pub fn parse_query(&mut self) -> Result<Query, ParserError> {
let ctes = if self.parse_keyword(Keyword::WITH) {
// TODO: optional RECURSIVE
self.parse_comma_separated(Parser::parse_cte)?
let with = if self.parse_keyword(Keyword::WITH) {
Some(With {
recursive: self.parse_keyword(Keyword::RECURSIVE),
cte_tables: self.parse_comma_separated(Parser::parse_cte)?,
})
} else {
vec![]
None
};

let body = self.parse_query_body(0)?;
Expand Down Expand Up @@ -1829,7 +1831,7 @@ impl<'a> Parser<'a> {
};

Ok(Query {
ctes,
with,
body,
limit,
order_by,
Expand Down
5 changes: 1 addition & 4 deletions src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,7 @@ impl<'a> Tokenizer<'a> {
// numbers
'0'..='9' => {
// TODO: https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#unsigned-numeric-literal
let s = peeking_take_while(chars, |ch| match ch {
'0'..='9' | '.' => true,
_ => false,
});
let s = peeking_take_while(chars, |ch| matches!(ch, '0'..='9' | '.'));
Ok(Some(Token::Number(s)))
}
// punctuation
Expand Down
47 changes: 42 additions & 5 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2389,7 +2389,7 @@ fn parse_ctes() {

fn assert_ctes_in_select(expected: &[&str], sel: &Query) {
for (i, exp) in expected.iter().enumerate() {
let Cte { alias, query } = &sel.ctes[i];
let Cte { alias, query } = &sel.with.as_ref().unwrap().cte_tables[i];
assert_eq!(*exp, query.to_string());
assert_eq!(
if i == 0 {
Expand Down Expand Up @@ -2432,7 +2432,7 @@ fn parse_ctes() {
// CTE in a CTE...
let sql = &format!("WITH outer_cte AS ({}) SELECT * FROM outer_cte", with);
let select = verified_query(sql);
assert_ctes_in_select(&cte_sqls, &only(&select.ctes).query);
assert_ctes_in_select(&cte_sqls, &only(&select.with.unwrap().cte_tables).query);
}

#[test]
Expand All @@ -2441,10 +2441,47 @@ fn parse_cte_renamed_columns() {
let query = all_dialects().verified_query(sql);
assert_eq!(
vec![Ident::new("col1"), Ident::new("col2")],
query.ctes.first().unwrap().alias.columns
query
.with
.unwrap()
.cte_tables
.first()
.unwrap()
.alias
.columns
);
}

#[test]
fn parse_recursive_cte() {
let cte_query = "SELECT 1 UNION ALL SELECT val + 1 FROM nums WHERE val < 10".to_owned();
let sql = &format!(
"WITH RECURSIVE nums (val) AS ({}) SELECT * FROM nums",
cte_query
);

let cte_query = verified_query(&cte_query);
let query = verified_query(sql);

let with = query.with.as_ref().unwrap();
assert!(with.recursive);
assert_eq!(with.cte_tables.len(), 1);
let expected = Cte {
alias: TableAlias {
name: Ident {
value: "nums".to_string(),
quote_style: None,
},
columns: vec![Ident {
value: "val".to_string(),
quote_style: None,
}],
},
query: cte_query,
};
assert_eq!(with.cte_tables.first().unwrap(), &expected);
}

#[test]
fn parse_derived_tables() {
let sql = "SELECT a.x, b.y FROM (SELECT x FROM foo) AS a CROSS JOIN (SELECT y FROM bar) AS b";
Expand Down Expand Up @@ -3266,8 +3303,8 @@ fn parse_drop_index() {
fn all_keywords_sorted() {
// assert!(ALL_KEYWORDS.is_sorted())
let mut copy = Vec::from(ALL_KEYWORDS);
copy.sort();
assert!(copy == ALL_KEYWORDS)
copy.sort_unstable();
assert_eq!(copy, ALL_KEYWORDS)
}

fn parse_sql_statements(sql: &str) -> Result<Vec<Statement>, ParserError> {
Expand Down