diff --git a/src/db/database.rs b/src/db/database.rs index a28fe26..da48369 100644 --- a/src/db/database.rs +++ b/src/db/database.rs @@ -1,5 +1,5 @@ use crate::db::table::{Table, Value}; -use crate::interpreter::ast::{SqlStatement, CreateTableStatement, InsertIntoStatement, SelectStatement, DeleteStatement, UpdateStatement}; +use crate::interpreter::ast::{SqlStatement, CreateTableStatement, InsertIntoStatement, SelectStatement, DeleteStatement, UpdateStatement, SelectStatementStackElement}; use crate::db::table::select; use crate::db::table::insert; use crate::db::table::delete; @@ -27,9 +27,19 @@ impl Database { self.insert_into_table(statement)?; Ok(None) }, - SqlStatement::Select(statement) => { - let rows = self.select_from_table(statement)?; - Ok(Some(rows)) + SqlStatement::Select(mut statement) => { + let select_statement = statement.elements.pop(); + if let Some(select_statement) = select_statement { + match select_statement { + SelectStatementStackElement::SelectStatement(select_statement) => { + let rows = self.select_from_table(select_statement)?; + Ok(Some(rows)) + } + _ => Err(format!("Expected select statement, got {:?}", select_statement)), + } + } else { + Ok(None) + } }, SqlStatement::UpdateStatement(statement) => { self.update_table(statement)?; diff --git a/src/interpreter/ast/helpers/mod.rs b/src/interpreter/ast/helpers/mod.rs index 27c6bf6..b9bdcce 100644 --- a/src/interpreter/ast/helpers/mod.rs +++ b/src/interpreter/ast/helpers/mod.rs @@ -2,4 +2,5 @@ pub mod where_stack; pub mod where_condition; pub mod order_by_clause; pub mod limit_clause; -pub mod common; \ No newline at end of file +pub mod common; +pub mod select_statement; \ No newline at end of file diff --git a/src/interpreter/ast/select_statement.rs b/src/interpreter/ast/helpers/select_statement.rs similarity index 86% rename from src/interpreter/ast/select_statement.rs rename to src/interpreter/ast/helpers/select_statement.rs index 7d3d734..7c708cf 100644 --- a/src/interpreter/ast/select_statement.rs +++ b/src/interpreter/ast/helpers/select_statement.rs @@ -1,15 +1,15 @@ use crate::{interpreter::{ ast::{ - parser::Parser, SelectStatement, SelectStatementColumns, SqlStatement, WhereStackElement, + parser::Parser, SelectStatement, SelectStatementColumns, WhereStackElement, helpers::{ - common::{expect_token_type, tokens_to_identifier_list, get_table_name}, + common::{tokens_to_identifier_list, get_table_name}, order_by_clause::get_order_by, where_stack::get_where_clause, limit_clause::get_limit } }, tokenizer::token::TokenTypes }}; -pub fn build(parser: &mut Parser) -> Result { +pub fn get_statement(parser: &mut Parser) -> Result { parser.advance()?; let columns = get_columns(parser)?; let table_name = get_table_name(parser)?; @@ -18,15 +18,13 @@ pub fn build(parser: &mut Parser) -> Result { let order_by_clause = get_order_by(parser)?; let limit_clause = get_limit(parser)?; - // Ensure SemiColon - expect_token_type(parser, TokenTypes::SemiColon)?; - return Ok(SqlStatement::Select(SelectStatement { - table_name: table_name, - columns: columns, - where_clause: where_clause, - order_by_clause: order_by_clause, - limit_clause: limit_clause, - })); + return Ok(SelectStatement { + table_name: table_name, + columns: columns, + where_clause: where_clause, + order_by_clause: order_by_clause, + limit_clause: limit_clause, + }); } fn get_columns(parser: &mut Parser) -> Result { @@ -65,16 +63,16 @@ mod tests { token(TokenTypes::SemiColon, ";"), ]; let mut parser = Parser::new(tokens); - let result = build(&mut parser); + let result = get_statement(&mut parser); assert!(result.is_ok()); let statement = result.unwrap(); - assert_eq!(statement, SqlStatement::Select(SelectStatement { + assert_eq!(statement, SelectStatement { table_name: "users".to_string(), columns: SelectStatementColumns::All, where_clause: None, order_by_clause: None, limit_clause: None, - })); + }); } #[test] @@ -88,10 +86,10 @@ mod tests { token(TokenTypes::SemiColon, ";"), ]; let mut parser = Parser::new(tokens); - let result = build(&mut parser); + let result = get_statement(&mut parser); assert!(result.is_ok()); let statement = result.unwrap(); - assert_eq!(statement, SqlStatement::Select(SelectStatement { + assert_eq!(statement, SelectStatement { table_name: "guests".to_string(), columns: SelectStatementColumns::Specific(vec![ "id".to_string(), @@ -99,7 +97,7 @@ mod tests { where_clause: None, order_by_clause: None, limit_clause: None, - })); + }); } #[test] @@ -115,10 +113,10 @@ mod tests { token(TokenTypes::SemiColon, ";"), ]; let mut parser = Parser::new(tokens); - let result = build(&mut parser); + let result = get_statement(&mut parser); assert!(result.is_ok()); let statement = result.unwrap(); - assert_eq!(statement, SqlStatement::Select(SelectStatement { + assert_eq!(statement, SelectStatement { table_name: "users".to_string(), columns: SelectStatementColumns::Specific(vec![ "id".to_string(), @@ -127,7 +125,7 @@ mod tests { where_clause: None, order_by_clause: None, limit_clause: None, - })); + }); } #[test] @@ -158,10 +156,10 @@ mod tests { token(TokenTypes::SemiColon, ";"), ]; let mut parser = Parser::new(tokens); - let result = build(&mut parser); + let result = get_statement(&mut parser); assert!(result.is_ok()); let statement = result.unwrap(); - assert_eq!(statement, SqlStatement::Select(SelectStatement { + let expected = SelectStatement { table_name: "guests".to_string(), columns: SelectStatementColumns::Specific(vec![ "id".to_string(), @@ -191,6 +189,7 @@ mod tests { limit: Value::Integer(10), offset: Some(Value::Integer(5)), }), - })); + }; + assert_eq!(expected, statement); } } \ No newline at end of file diff --git a/src/interpreter/ast/helpers/where_stack.rs b/src/interpreter/ast/helpers/where_stack.rs index 9c47727..9a77c4d 100644 --- a/src/interpreter/ast/helpers/where_stack.rs +++ b/src/interpreter/ast/helpers/where_stack.rs @@ -27,7 +27,7 @@ pub fn get_where_clause(parser: &mut Parser) -> Result { match where_stack_element { @@ -131,7 +131,7 @@ pub fn get_where_clause(parser: &mut Parser) -> Result Result, String> { +fn get_where_condition(parser: &mut Parser, operator_stack: &Vec) -> Result, String> { let token = parser.current_token()?; match token.token_type { // Logical operators and parentheses @@ -152,8 +152,18 @@ fn get_where_condition(parser: &mut Parser) -> Result, return Ok(Some(WhereStackElement::Parentheses(Parentheses::Left))) }, TokenTypes::RightParen => { - parser.advance()?; - return Ok(Some(WhereStackElement::Parentheses(Parentheses::Right))) + // TODO improve this check. + let has_matching_left_paren = operator_stack.iter().any(|op| { + matches!(op, WhereStackOperators::Parentheses(Parentheses::Left)) + }); + + if has_matching_left_paren { + parser.advance()?; + return Ok(Some(WhereStackElement::Parentheses(Parentheses::Right))) + } else { + // We may have a mismatched parenthesis from the UNION STATEMENTs causing this. + return Ok(None); + } }, // Conditions TokenTypes::Identifier | TokenTypes::IntLiteral | TokenTypes::RealLiteral | TokenTypes::String | TokenTypes::Blob | TokenTypes::Null => { @@ -398,7 +408,8 @@ mod tests { } #[test] - fn returns_error_for_extra_closing_parenthesis() { + fn does_not_return_error_for_extra_closing_parenthesis() { + // This extra closing parenthesis could be from the UNION STATEMENTs and would error on that level. // WHERE (id = 1 OR name = "John")); (extra closing parenthesis) let tokens = vec![ token(TokenTypes::Where, "WHERE"), @@ -416,9 +427,13 @@ mod tests { ]; let mut parser = Parser::new(tokens); let result = get_where_clause(&mut parser); - assert!(result.is_err()); - assert_eq!(result.unwrap_err(), "Mismatched parentheses found."); - assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Some(vec![ + simple_condition("id", Operator::Equals, Value::Integer(1)), + simple_condition("name", Operator::Equals, Value::Text("John".to_string())), + WhereStackElement::LogicalOperator(LogicalOperator::Or), + ])); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::RightParen); } #[test] diff --git a/src/interpreter/ast/mod.rs b/src/interpreter/ast/mod.rs index 1b76159..d401a9e 100644 --- a/src/interpreter/ast/mod.rs +++ b/src/interpreter/ast/mod.rs @@ -4,7 +4,7 @@ use crate::db::table::{ColumnDefinition, Value}; mod create_statement; mod insert_statement; mod parser; -mod select_statement; +mod select_statement_stack; mod update_statement; mod delete_statement; mod helpers; @@ -15,7 +15,7 @@ mod test_utils; pub enum SqlStatement { CreateTable(CreateTableStatement), InsertInto(InsertIntoStatement), - Select(SelectStatement), + Select(SelectStatementStack), UpdateStatement(UpdateStatement), DeleteStatement(DeleteStatement), } @@ -33,6 +33,41 @@ pub struct InsertIntoStatement { pub values: Vec>, } +#[derive(Debug, PartialEq)] +pub struct SelectStatementStack { + pub elements: Vec, +} + +#[derive(Debug, PartialEq)] +pub enum SelectStatementStackElement { + SelectStatement(SelectStatement), + SetOperator(SetOperator), +} + +#[derive(Debug, PartialEq)] +pub enum SelectStackOperators { + SetOperator(SetOperator), + Parentheses(Parentheses), +} + +#[derive(Debug, PartialEq)] +pub enum SetOperator { + Union, + UnionAll, + Intersect, + Except, +} + +impl SetOperator { + pub fn is_greater_precedence(&self, other: &SetOperator) -> bool { + match (self, other) { + (SetOperator::Intersect, SetOperator::Intersect) => false, + (SetOperator::Intersect, _) => true, + (_, _) => false, + } + } +} + #[derive(Debug, PartialEq)] pub struct SelectStatement { pub table_name: String, @@ -188,7 +223,7 @@ impl StatementBuilder for DefaultStatementBuilder { } fn build_select(&self, parser: &mut parser::Parser) -> Result { - select_statement::build(parser) + select_statement_stack::build(parser) } fn build_update(&self, parser: &mut parser::Parser) -> Result { @@ -284,12 +319,14 @@ mod tests { assert!(result[0].is_ok()); assert!(result[1].is_ok()); let expected = vec![ - Ok(SqlStatement::Select(SelectStatement { - table_name: "users".to_string(), - columns: SelectStatementColumns::All, - where_clause: None, - order_by_clause: None, - limit_clause: None, + Ok(SqlStatement::Select(SelectStatementStack { + elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement { + table_name: "users".to_string(), + columns: SelectStatementColumns::All, + where_clause: None, + order_by_clause: None, + limit_clause: None, + })], })), Ok(SqlStatement::InsertInto(InsertIntoStatement { table_name: "users".to_string(), @@ -362,12 +399,14 @@ mod tests { assert!(result[0].is_ok()); assert!(result[1].is_ok()); let expected = vec![ - Ok(SqlStatement::Select(SelectStatement { - table_name: "users".to_string(), - columns: SelectStatementColumns::All, - where_clause: None, - order_by_clause: None, - limit_clause: None, + Ok(SqlStatement::Select(SelectStatementStack { + elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement { + table_name: "users".to_string(), + columns: SelectStatementColumns::All, + where_clause: None, + order_by_clause: None, + limit_clause: None, + })], })), Ok(SqlStatement::InsertInto(InsertIntoStatement { table_name: "users".to_string(), diff --git a/src/interpreter/ast/parser.rs b/src/interpreter/ast/parser.rs index 9494c74..3fc0be5 100644 --- a/src/interpreter/ast/parser.rs +++ b/src/interpreter/ast/parser.rs @@ -23,6 +23,13 @@ impl<'a> Parser<'a> { return Ok(&self.tokens[self.current]); } + pub fn peek_token(&self) -> Result<&Token<'a>, String> { + if self.current + 1 >= self.tokens.len() { + return Err(self.format_error()); + } + return Ok(&self.tokens[self.current + 1]); + } + pub fn advance(&mut self) -> Result<(), String> { if let Ok(token) = self.current_token() { if token.token_type == TokenTypes::SemiColon { @@ -68,19 +75,22 @@ impl<'a> Parser<'a> { } pub fn next_statement(&mut self, builder: &dyn StatementBuilder) -> Option> { - match self.current_token() { - Ok(token) => match token.token_type { - TokenTypes::Create => Some(builder.build_create(self)), - TokenTypes::Insert => Some(builder.build_insert(self)), - TokenTypes::Select => Some(builder.build_select(self)), - TokenTypes::Update => Some(builder.build_update(self)), - TokenTypes::Delete => Some(builder.build_delete(self)), - TokenTypes::EOF => None, + match (&self.current_token(), &self.peek_token()) { + (Ok(token), Ok(peek_token)) => match (&token.token_type, &peek_token.token_type) { + (TokenTypes::Create, _) => Some(builder.build_create(self)), + (TokenTypes::Insert, _) => Some(builder.build_insert(self)), + (TokenTypes::Select, _) | (TokenTypes::LeftParen, TokenTypes::Select) => Some(builder.build_select(self)), + (TokenTypes::Update, _) => Some(builder.build_update(self)), + (TokenTypes::Delete, _) => Some(builder.build_delete(self)), _ => { Some(Err(self.format_error())) } }, - Err(error) => Some(Err(error)), + (Ok(token), Err(_)) => match token.token_type { + TokenTypes::EOF => None, + _ => Some(Err(self.format_error_nearby())), + }, + _ => Some(Err(self.format_error())), } } } @@ -89,7 +99,7 @@ impl<'a> Parser<'a> { #[cfg(test)] mod tests { use super::*; - use crate::interpreter::ast::{CreateTableStatement, InsertIntoStatement, SelectStatement, SelectStatementColumns}; + use crate::interpreter::ast::{CreateTableStatement, InsertIntoStatement, SelectStatement, SelectStatementColumns, SelectStatementStack, SelectStatementStackElement}; use crate::interpreter::ast::test_utils::{token_with_location, token}; #[test] @@ -133,12 +143,14 @@ mod tests { fn build_select(&self, parser: &mut Parser) -> Result { parser.advance()?; parser.advance_past_semicolon()?; - return Ok(SqlStatement::Select(SelectStatement { - table_name: "users".to_string(), - columns: SelectStatementColumns::All, - where_clause: None, - order_by_clause: None, - limit_clause: None, + return Ok(SqlStatement::Select(SelectStatementStack { + elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement { + table_name: "users".to_string(), + columns: SelectStatementColumns::All, + where_clause: None, + order_by_clause: None, + limit_clause: None, + })], })); } @@ -183,12 +195,14 @@ mod tests { // Select let result = parser.next_statement(builder); - let expected = Some(Ok(SqlStatement::Select(SelectStatement { - table_name: "users".to_string(), - columns: SelectStatementColumns::All, - where_clause: None, - order_by_clause: None, - limit_clause: None, + let expected = Some(Ok(SqlStatement::Select(SelectStatementStack { + elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement { + table_name: "users".to_string(), + columns: SelectStatementColumns::All, + where_clause: None, + order_by_clause: None, + limit_clause: None, + })], }))); assert_eq!(result, expected); diff --git a/src/interpreter/ast/select_statement_stack.rs b/src/interpreter/ast/select_statement_stack.rs new file mode 100644 index 0000000..63c36fa --- /dev/null +++ b/src/interpreter/ast/select_statement_stack.rs @@ -0,0 +1,254 @@ +use crate::interpreter::ast::{parser::Parser, SqlStatement, SelectStatementStack, SelectStatementStackElement, SetOperator, SelectStackOperators}; +use crate::interpreter::ast::helpers::select_statement; +use crate::interpreter::ast::Parentheses; +use crate::interpreter::tokenizer::token::TokenTypes; + +// Returns a SelectStatementStack which is an RPN representation of the SELECT statements and set operators. +pub fn build(parser: &mut Parser) -> Result { + let mut select_statement_stack: Vec = vec![]; + let mut set_operator_stack: Vec = vec![]; + + loop { + let token = parser.current_token()?; + match token.token_type { + TokenTypes::Select => { + let statement = select_statement::get_statement(parser)?; + select_statement_stack.push(SelectStatementStackElement::SelectStatement(statement)); + } + TokenTypes::LeftParen => { + set_operator_stack.push(SelectStackOperators::Parentheses(Parentheses::Left)); + parser.advance()?; + } + TokenTypes::RightParen => { + while let Some(current_set_operator) = set_operator_stack.pop() { + if let SelectStackOperators::Parentheses(_) = current_set_operator { + break; + } + else if let SelectStackOperators::SetOperator(set_operator) = current_set_operator { + select_statement_stack.push(SelectStatementStackElement::SetOperator(set_operator)); + } + else { + return Err("Mismatched parentheses found.".to_string()); + } + } + parser.advance()?; + } + TokenTypes::Union | TokenTypes::Except => { + let set_operator = get_set_operator(parser)?; + while let Some(current_set_operator) = set_operator_stack.pop() { + if let SelectStackOperators::Parentheses(parentheses) = current_set_operator { + set_operator_stack.push(SelectStackOperators::Parentheses(parentheses)); + break; + } + else if let SelectStackOperators::SetOperator(current_set_operator) = current_set_operator { + select_statement_stack.push(SelectStatementStackElement::SetOperator(current_set_operator)); + } + } + set_operator_stack.push(SelectStackOperators::SetOperator(set_operator)); + } + TokenTypes::Intersect => { + let set_operator = get_set_operator(parser)?; + while let Some(current_set_operator) = set_operator_stack.pop() { + if let SelectStackOperators::SetOperator(current_set_operator) = current_set_operator { + if set_operator.is_greater_precedence(¤t_set_operator) { + set_operator_stack.push(SelectStackOperators::SetOperator(current_set_operator)); + break; + } + else { + select_statement_stack.push(SelectStatementStackElement::SetOperator(current_set_operator)); + } + } + else { + set_operator_stack.push(current_set_operator); + break; + } + } + set_operator_stack.push(SelectStackOperators::SetOperator(set_operator)); + } + TokenTypes::SemiColon => break, + _ => return Err(parser.format_error()), + } + } + + while let Some(current_set_operator) = set_operator_stack.pop() { + if let SelectStackOperators::SetOperator(set_operator) = current_set_operator { + select_statement_stack.push(SelectStatementStackElement::SetOperator(set_operator)); + } + else { + return Err("Mismatched parentheses found.".to_string()); + } + } + + return Ok(SqlStatement::Select(SelectStatementStack { + elements: select_statement_stack, + })); +} + +fn get_set_operator(parser: &mut Parser) -> Result { + let token = parser.current_token()?; + let set_operator = match token.token_type { + TokenTypes::Union => { + if parser.peek_token()?.token_type == TokenTypes::All { + parser.advance()?; + Ok(SetOperator::UnionAll) + } else { + Ok(SetOperator::Union) + } + }, + TokenTypes::Except => { + Ok(SetOperator::Except) + }, + TokenTypes::Intersect => { + Ok(SetOperator::Intersect) + }, + _ => Err("Expected token type: Union, Except, or Intersect was not found".to_string()), + }; + parser.advance()?; + return set_operator; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::interpreter::ast::test_utils::token; + use crate::interpreter::ast::SelectStatement; + use crate::interpreter::ast::SetOperator; + use crate::interpreter::ast::SelectStatementColumns; + use crate::interpreter::ast::WhereStackElement; + use crate::interpreter::ast::WhereCondition; + use crate::interpreter::ast::Operand; + use crate::interpreter::ast::Operator; + use crate::db::table::Value; + use crate::interpreter::tokenizer::token::TokenTypes; + use crate::interpreter::tokenizer::scanner::Token; + + fn simple_select_statement_tokens(id: &'static str) -> Vec> { + vec![ + token(TokenTypes::Select, "SELECT"), + token(TokenTypes::Asterisk, "*"), + token(TokenTypes::From, "FROM"), + token(TokenTypes::Identifier, "users"), + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, id), + ] + } + + fn expected_simple_select_statement(id: i64) -> SelectStatementStackElement { + SelectStatementStackElement::SelectStatement(SelectStatement { + table_name: "users".to_string(), + columns: SelectStatementColumns::All, + where_clause: Some(vec![WhereStackElement::Condition(WhereCondition { + l_side: Operand::Identifier("id".to_string()), + operator: Operator::Equals, + r_side: Operand::Value(Value::Integer(id)), + })]), + order_by_clause: None, + limit_clause: None, + }) + } + + + #[test] + fn simple_select_statement_is_generated_correctly() { + // SELECT * FROM users WHERE id = 1; + let mut tokens = simple_select_statement_tokens("1"); + tokens.append(&mut vec![token(TokenTypes::SemiColon, ";")]); + let mut parser = Parser::new(tokens); + let result = build(&mut parser); + assert!(result.is_ok()); + let statement = result.unwrap(); + let expected = SqlStatement::Select(SelectStatementStack { + elements: vec![expected_simple_select_statement(1)], + }); + assert_eq!(expected, statement); + } + + #[test] + fn select_statement_with_set_operator_is_generated_correctly() { + // SELECT * FROM users WHERE id = 1 UNION ALL SELECT * FROM users WHERE id = 2; + let mut tokens = simple_select_statement_tokens("1"); + tokens.append(&mut vec![token(TokenTypes::Union, "UNION"), token(TokenTypes::All, "ALL")]); + tokens.append(&mut simple_select_statement_tokens("2")); + tokens.append(&mut vec![token(TokenTypes::SemiColon, ";")]); + let mut parser = Parser::new(tokens); + let result = build(&mut parser); + println!("{:?}", result); + assert!(result.is_ok()); + let statement = result.unwrap(); + let expected = SqlStatement::Select(SelectStatementStack { + elements: vec![ + expected_simple_select_statement(1), + expected_simple_select_statement(2), + SelectStatementStackElement::SetOperator(SetOperator::UnionAll), + ], + }); + assert_eq!(expected, statement); + } + + #[test] + fn select_statement_with_multiple_set_operators_is_generated_correctly() { + // SELECT 1 ... UNION ALL SELECT 2 ... INTERSECT SELECT 3 ... EXCEPT SELECT 4 ...; + let mut tokens = simple_select_statement_tokens("1"); + tokens.append(&mut vec![token(TokenTypes::Union, "UNION")]); + tokens.append(&mut simple_select_statement_tokens("2")); + tokens.append(&mut vec![token(TokenTypes::Intersect, "INTERSECT")]); + tokens.append(&mut simple_select_statement_tokens("3")); + tokens.append(&mut vec![token(TokenTypes::Except, "EXCEPT")]); + tokens.append(&mut simple_select_statement_tokens("4")); + tokens.append(&mut vec![token(TokenTypes::SemiColon, ";")]); + let mut parser = Parser::new(tokens); + let result = build(&mut parser); + println!("{:?}", result); + assert!(result.is_ok()); + let statement = result.unwrap(); + let expected = SqlStatement::Select(SelectStatementStack { + elements: vec![ + expected_simple_select_statement(1), + expected_simple_select_statement(2), + expected_simple_select_statement(3), + SelectStatementStackElement::SetOperator(SetOperator::Intersect), + SelectStatementStackElement::SetOperator(SetOperator::Union), + expected_simple_select_statement(4), + SelectStatementStackElement::SetOperator(SetOperator::Except), + ], + }); + assert_eq!(expected, statement); + } + + #[test] + fn select_statement_with_multiple_set_operators_and_parentheses_is_generated_correctly() { + // (SELECT 1 ... UNION ALL SELECT 2 ...) INTERSECT (SELECT 3 ... EXCEPT SELECT 4 ...); + let mut tokens = vec![token(TokenTypes::LeftParen, "(")]; + tokens.append(&mut simple_select_statement_tokens("1")); + tokens.append(&mut vec![token(TokenTypes::Union, "UNION")]); + tokens.append(&mut vec![token(TokenTypes::All, "ALL")]); + tokens.append(&mut simple_select_statement_tokens("2")); + tokens.append(&mut vec![token(TokenTypes::RightParen, ")")]); + tokens.append(&mut vec![token(TokenTypes::Intersect, "INTERSECT")]); + tokens.append(&mut vec![token(TokenTypes::LeftParen, "(")]); + tokens.append(&mut simple_select_statement_tokens("3")); + tokens.append(&mut vec![token(TokenTypes::Except, "EXCEPT")]); + tokens.append(&mut simple_select_statement_tokens("4")); + tokens.append(&mut vec![token(TokenTypes::RightParen, ")")]); + tokens.append(&mut vec![token(TokenTypes::SemiColon, ";")]); + let mut parser = Parser::new(tokens); + let result = build(&mut parser); + println!("{:?}", result); + assert!(result.is_ok()); + let statement = result.unwrap(); + let expected = SqlStatement::Select(SelectStatementStack { + elements: vec![ + expected_simple_select_statement(1), + expected_simple_select_statement(2), + SelectStatementStackElement::SetOperator(SetOperator::UnionAll), + expected_simple_select_statement(3), + expected_simple_select_statement(4), + SelectStatementStackElement::SetOperator(SetOperator::Except), + SelectStatementStackElement::SetOperator(SetOperator::Intersect), + ], + }); + assert_eq!(expected, statement); + } +} \ No newline at end of file diff --git a/src/interpreter/tokenizer/scanner.rs b/src/interpreter/tokenizer/scanner.rs index b5b48c1..b83f3a7 100644 --- a/src/interpreter/tokenizer/scanner.rs +++ b/src/interpreter/tokenizer/scanner.rs @@ -159,6 +159,9 @@ impl<'a> Scanner<'a> { slice if slice.eq_ignore_ascii_case("UNION") => TokenTypes::Union, slice if slice.eq_ignore_ascii_case("LIMIT") => TokenTypes::Limit, slice if slice.eq_ignore_ascii_case("OFFSET") => TokenTypes::Offset, + slice if slice.eq_ignore_ascii_case("UNION") => TokenTypes::Union, + slice if slice.eq_ignore_ascii_case("INTERSECT") => TokenTypes::Intersect, + slice if slice.eq_ignore_ascii_case("EXCEPT") => TokenTypes::Except, slice if slice.eq_ignore_ascii_case("AND") => TokenTypes::And, slice if slice.eq_ignore_ascii_case("OR") => TokenTypes::Or, slice if slice.eq_ignore_ascii_case("IN") => TokenTypes::In, diff --git a/src/interpreter/tokenizer/token.rs b/src/interpreter/tokenizer/token.rs index 517b11f..1706f7b 100644 --- a/src/interpreter/tokenizer/token.rs +++ b/src/interpreter/tokenizer/token.rs @@ -9,8 +9,8 @@ pub enum TokenTypes { Primary, Key, Not, Unique, Default, AutoIncrement, // Clauses Order, By, Group, Having, Distinct, All, As, Asc, Desc, - Inner, Left, Right, Full, Outer, Join, On, Union, - Limit, Offset, + Inner, Left, Right, Full, Outer, Join, On, + Limit, Offset, Union, Intersect, Except, // Logical Operators And, Or, In, Exists, Case, When, Then, Else, End, Is,