diff --git a/src/interpreter/ast/delete_statement.rs b/src/interpreter/ast/delete_statement.rs index 030df8b..5bd8f56 100644 --- a/src/interpreter/ast/delete_statement.rs +++ b/src/interpreter/ast/delete_statement.rs @@ -3,7 +3,7 @@ use crate::interpreter::{ parser::Parser, SqlStatement, DeleteStatement, helpers::{ common::{expect_token_type, get_table_name}, - order_by_clause::get_order_by, where_stack::get_where_clause, limit_clause::get_limit + order_by_clause::get_order_by, where_clause::get_where_clause, limit_clause::get_limit } }, tokenizer::token::TokenTypes diff --git a/src/interpreter/ast/helpers/mod.rs b/src/interpreter/ast/helpers/mod.rs index b9bdcce..d640337 100644 --- a/src/interpreter/ast/helpers/mod.rs +++ b/src/interpreter/ast/helpers/mod.rs @@ -1,5 +1,4 @@ -pub mod where_stack; -pub mod where_condition; +pub mod where_clause; pub mod order_by_clause; pub mod limit_clause; pub mod common; diff --git a/src/interpreter/ast/helpers/select_statement.rs b/src/interpreter/ast/helpers/select_statement.rs index c0c7955..782d7b0 100644 --- a/src/interpreter/ast/helpers/select_statement.rs +++ b/src/interpreter/ast/helpers/select_statement.rs @@ -3,7 +3,7 @@ use crate::{interpreter::{ parser::Parser, SelectStatement, SelectStatementColumns, WhereStackElement, SelectMode, helpers::{ common::{tokens_to_identifier_list, get_table_name, expect_token_type}, - order_by_clause::get_order_by, where_stack::get_where_clause, limit_clause::get_limit + order_by_clause::get_order_by, where_clause::get_where_clause, limit_clause::get_limit } }, tokenizer::token::TokenTypes diff --git a/src/interpreter/ast/helpers/where_clause/expected_token_matches_current.rs b/src/interpreter/ast/helpers/where_clause/expected_token_matches_current.rs new file mode 100644 index 0000000..ab80810 --- /dev/null +++ b/src/interpreter/ast/helpers/where_clause/expected_token_matches_current.rs @@ -0,0 +1,105 @@ +use crate::interpreter::ast::{WhereStackElement, LogicalOperator, Parentheses, 
parser::Parser}; + + +#[derive(PartialEq, Debug)] +pub enum WhereClauseExpectedNextToken { + ConditionLeftParenNot, + LogicalOperatorRightParen, +} + +// This function ensures that the current where stack element is correct based on the previous. +// Raises parser errors for strings like `WHERE NOT AND 1 = 1`, `WHERE 1 = 1 2 = 2`, or `WHERE ()`. +pub fn next_expected_token_from_current(expected_token: &WhereClauseExpectedNextToken, where_stack_element: &WhereStackElement, parser: &mut Parser) -> Result { + match where_stack_element { + WhereStackElement::Condition(_) => { + if *expected_token != WhereClauseExpectedNextToken::ConditionLeftParenNot { + return Err(parser.format_error_nearby()); + } + Ok(WhereClauseExpectedNextToken::LogicalOperatorRightParen) + }, + WhereStackElement::LogicalOperator(logical_operator) => { + match logical_operator { + LogicalOperator::Not => { + if *expected_token != WhereClauseExpectedNextToken::ConditionLeftParenNot { + return Err(parser.format_error_nearby()); + } + Ok(WhereClauseExpectedNextToken::ConditionLeftParenNot) + }, + _ => { + if *expected_token != WhereClauseExpectedNextToken::LogicalOperatorRightParen { + return Err(parser.format_error_nearby()); + } + Ok(WhereClauseExpectedNextToken::ConditionLeftParenNot) + } + } + }, + WhereStackElement::Parentheses(parentheses) => { + match parentheses { + Parentheses::Left => { + if *expected_token != WhereClauseExpectedNextToken::ConditionLeftParenNot { + return Err(parser.format_error_nearby()); + } + Ok(WhereClauseExpectedNextToken::ConditionLeftParenNot) + }, + Parentheses::Right => { + if *expected_token != WhereClauseExpectedNextToken::LogicalOperatorRightParen { + return Err(parser.format_error_nearby()); + } + Ok(WhereClauseExpectedNextToken::LogicalOperatorRightParen) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::interpreter::ast::{WhereCondition, Operand, Operator, Value}; + + #[test] + fn handles_next_condition_being_condition() { + let mut 
parser = Parser::new(vec![]); + let where_stack_element = WhereStackElement::Condition(WhereCondition {l_side: Operand::Identifier("id".to_string()),operator:Operator::Equals,r_side: Operand::Value(Value::Integer(1))}); + let expected_token = WhereClauseExpectedNextToken::LogicalOperatorRightParen; + assert!(next_expected_token_from_current(&expected_token, &where_stack_element, &mut parser).is_err()); + let expected_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; + assert!(next_expected_token_from_current(&expected_token, &where_stack_element, &mut parser).is_ok()); + } + + #[test] + fn handles_next_condition_being_logical_operator() { + let mut parser = Parser::new(vec![]); + // Not operator + let where_stack_element = WhereStackElement::LogicalOperator(LogicalOperator::Not); + let expected_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; + assert!(next_expected_token_from_current(&expected_token, &where_stack_element, &mut parser).is_ok()); + let expected_token = WhereClauseExpectedNextToken::LogicalOperatorRightParen; + assert!(next_expected_token_from_current(&expected_token, &where_stack_element, &mut parser).is_err()); + + // Other logical operator (I used AND but OR should be the same) + let where_stack_element = WhereStackElement::LogicalOperator(LogicalOperator::And); + let expected_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; + assert!(next_expected_token_from_current(&expected_token, &where_stack_element, &mut parser).is_err()); + let expected_token = WhereClauseExpectedNextToken::LogicalOperatorRightParen; + assert!(next_expected_token_from_current(&expected_token, &where_stack_element, &mut parser).is_ok()); + } + + #[test] + fn handles_next_condition_being_parentheses() { + let mut parser = Parser::new(vec![]); + // Left parentheses + let where_stack_element = WhereStackElement::Parentheses(Parentheses::Left); + let expected_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; + 
assert!(next_expected_token_from_current(&expected_token, &where_stack_element, &mut parser).is_ok()); + let expected_token = WhereClauseExpectedNextToken::LogicalOperatorRightParen; + assert!(next_expected_token_from_current(&expected_token, &where_stack_element, &mut parser).is_err()); + + // Right parentheses + let where_stack_element = WhereStackElement::Parentheses(Parentheses::Right); + let expected_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; + assert!(next_expected_token_from_current(&expected_token, &where_stack_element, &mut parser).is_err()); + let expected_token = WhereClauseExpectedNextToken::LogicalOperatorRightParen; + assert!(next_expected_token_from_current(&expected_token, &where_stack_element, &mut parser).is_ok()); + } +} \ No newline at end of file diff --git a/src/interpreter/ast/helpers/where_stack.rs b/src/interpreter/ast/helpers/where_clause/mod.rs similarity index 66% rename from src/interpreter/ast/helpers/where_stack.rs rename to src/interpreter/ast/helpers/where_clause/mod.rs index 9a77c4d..c789c36 100644 --- a/src/interpreter/ast/helpers/where_stack.rs +++ b/src/interpreter/ast/helpers/where_clause/mod.rs @@ -1,6 +1,11 @@ +mod where_condition; +mod expected_token_matches_current; +mod where_stack_element; +use expected_token_matches_current::{next_expected_token_from_current, WhereClauseExpectedNextToken}; + use crate::interpreter::{ast::{ - helpers::{common::expect_token_type, where_condition::get_condition}, - parser::Parser, LogicalOperator, WhereStackElement, WhereStackOperators, Parentheses}}; + helpers::{common::expect_token_type, where_clause::where_stack_element::get_where_stack_element}, + parser::Parser, WhereStackElement, WhereStackOperators, Parentheses}}; use crate::interpreter::tokenizer::token::TokenTypes; // The WhereStack is a the method that is used to store the order of operations with Reverse Polish Notation. 
@@ -9,14 +14,6 @@ use crate::interpreter::tokenizer::token::TokenTypes; // This is currently represented as stack of LogicalOperators, WhereConditions. // WhereConditions are currently represented as 'column operator value' // This will later be expanded to replace the WhereConditions with a generalized evaluation function. - -// We also validate the order using an enum of the current next expected token types. -#[derive(PartialEq, Debug)] -enum WhereClauseExpectedNextToken { - ConditionLeftParenNot, - LogicalOperatorRightParen, -} - pub fn get_where_clause(parser: &mut Parser) -> Result>, String> { if expect_token_type(parser, TokenTypes::Where).is_err() { return Ok(None); @@ -24,106 +21,59 @@ pub fn get_where_clause(parser: &mut Parser) -> Result = vec![]; let mut operator_stack: Vec = vec![]; - let mut expected_next_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; + let mut expected_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; - loop { - let where_condition = get_where_condition(parser, &operator_stack)?; - match where_condition { - Some(where_stack_element) => { - match where_stack_element { - WhereStackElement::Condition(where_condition) => { - if expected_next_token != WhereClauseExpectedNextToken::ConditionLeftParenNot { - return Err(parser.format_error_nearby()); - } - expected_next_token = WhereClauseExpectedNextToken::LogicalOperatorRightParen; - where_stack.push(WhereStackElement::Condition(where_condition)); - }, - WhereStackElement::Parentheses(parentheses) => { - match parentheses { - Parentheses::Left => { - if expected_next_token != WhereClauseExpectedNextToken::ConditionLeftParenNot { - return Err(parser.format_error_nearby()); - } - operator_stack.push(WhereStackOperators::Parentheses(parentheses)); - }, - Parentheses::Right => { - if expected_next_token != WhereClauseExpectedNextToken::LogicalOperatorRightParen { - return Err(parser.format_error_nearby()); - } - expected_next_token = 
WhereClauseExpectedNextToken::LogicalOperatorRightParen; - loop { - let current_operator = operator_stack.pop(); - if let Some (current_operator) = current_operator { - match current_operator { - WhereStackOperators::Parentheses(Parentheses::Left) => { - break; - }, - WhereStackOperators::LogicalOperator(logical_operator) => { - where_stack.push(WhereStackElement::LogicalOperator(logical_operator)); - }, - WhereStackOperators::Parentheses(Parentheses::Right) => { - return Err("Mismatched parentheses found.".to_string()); - }, - } - } - else { - return Err("Mismatched parentheses found.".to_string()); - } - } - }, - } - }, - WhereStackElement::LogicalOperator(logical_operator) => { - match logical_operator { - LogicalOperator::Not => { - if expected_next_token != WhereClauseExpectedNextToken::ConditionLeftParenNot { - return Err(parser.format_error_nearby()); - } - expected_next_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; - } - _ => { - if expected_next_token != WhereClauseExpectedNextToken::LogicalOperatorRightParen { - return Err(parser.format_error_nearby()); - } - expected_next_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; + while let Some(where_stack_element) = get_where_stack_element(parser, &operator_stack)? 
{ + expected_token = next_expected_token_from_current(&expected_token, &where_stack_element, parser)?; + match where_stack_element { + WhereStackElement::Condition(condition) => where_stack.push(WhereStackElement::Condition(condition)), + WhereStackElement::Parentheses(parentheses) => { + if parentheses == Parentheses::Left { + operator_stack.push(WhereStackOperators::Parentheses(parentheses)); + continue; + } + while let Some(current_operator) = operator_stack.pop() { + match (current_operator, operator_stack.len()) { + (WhereStackOperators::LogicalOperator(_), 0) => return Err("Mismatched parentheses found.".to_string()), + (WhereStackOperators::Parentheses(Parentheses::Left), _) => break, + (WhereStackOperators::LogicalOperator(logical_operator), _) => where_stack.push(WhereStackElement::LogicalOperator(logical_operator)), + _ => unreachable!(), + } + } + }, + WhereStackElement::LogicalOperator(logical_operator) => { + loop { + let current_operator = match operator_stack.pop() { + Some(operator) => operator, + None => { + operator_stack.push(WhereStackOperators::LogicalOperator(logical_operator)); + break; + }, + }; + match current_operator { + WhereStackOperators::LogicalOperator(current_logical_operator) => { + if !logical_operator.is_greater_precedence(&current_logical_operator) { + where_stack.push(WhereStackElement::LogicalOperator(current_logical_operator)); } - } - loop { - let current_operator = if let Some(operator) = operator_stack.pop() { - operator - } else { + else { + operator_stack.push(WhereStackOperators::LogicalOperator(current_logical_operator)); operator_stack.push(WhereStackOperators::LogicalOperator(logical_operator)); break; - }; - match current_operator { - WhereStackOperators::LogicalOperator(current_operator) => { - if logical_operator.is_greater_precedence(&current_operator) { - operator_stack.push(WhereStackOperators::LogicalOperator(current_operator)); - operator_stack.push(WhereStackOperators::LogicalOperator(logical_operator)); - break; - } - 
else { - where_stack.push(WhereStackElement::LogicalOperator(current_operator)); - } - }, - _ => { - operator_stack.push(current_operator); - operator_stack.push(WhereStackOperators::LogicalOperator(logical_operator)); - break; - }, } - } + }, + _ => { + operator_stack.push(current_operator); + operator_stack.push(WhereStackOperators::LogicalOperator(logical_operator)); + break; + }, } } } - None => break } } while let Some(operator) = operator_stack.pop() { match operator { - WhereStackOperators::LogicalOperator(logical_operator) => { - where_stack.push(WhereStackElement::LogicalOperator(logical_operator)); - }, + WhereStackOperators::LogicalOperator(_) => where_stack.push(operator.into_where_stack_element()), _ => return Err("Mismatched parentheses found.".to_string()), } } @@ -131,48 +81,6 @@ pub fn get_where_clause(parser: &mut Parser) -> Result) -> Result, String> { - let token = parser.current_token()?; - match token.token_type { - // Logical operators and parentheses - TokenTypes::And => { - parser.advance()?; - return Ok(Some(WhereStackElement::LogicalOperator(LogicalOperator::And))) - }, - TokenTypes::Or => { - parser.advance()?; - return Ok(Some(WhereStackElement::LogicalOperator(LogicalOperator::Or))) - }, - TokenTypes::Not => { - parser.advance()?; - return Ok(Some(WhereStackElement::LogicalOperator(LogicalOperator::Not))) - }, - TokenTypes::LeftParen => { - parser.advance()?; - return Ok(Some(WhereStackElement::Parentheses(Parentheses::Left))) - }, - TokenTypes::RightParen => { - // TODO improve this check. - let has_matching_left_paren = operator_stack.iter().any(|op| { - matches!(op, WhereStackOperators::Parentheses(Parentheses::Left)) - }); - - if has_matching_left_paren { - parser.advance()?; - return Ok(Some(WhereStackElement::Parentheses(Parentheses::Right))) - } else { - // We may have a mismatched parenthesis from the UNION STATEMENTs causing this. 
- return Ok(None); - } - }, - // Conditions - TokenTypes::Identifier | TokenTypes::IntLiteral | TokenTypes::RealLiteral | TokenTypes::String | TokenTypes::Blob | TokenTypes::Null => { - return Ok(Some(WhereStackElement::Condition(get_condition(parser)?))); - } - _ => return Ok(None), - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/src/interpreter/ast/helpers/where_condition.rs b/src/interpreter/ast/helpers/where_clause/where_condition.rs similarity index 100% rename from src/interpreter/ast/helpers/where_condition.rs rename to src/interpreter/ast/helpers/where_clause/where_condition.rs diff --git a/src/interpreter/ast/helpers/where_clause/where_stack_element.rs b/src/interpreter/ast/helpers/where_clause/where_stack_element.rs new file mode 100644 index 0000000..44c6695 --- /dev/null +++ b/src/interpreter/ast/helpers/where_clause/where_stack_element.rs @@ -0,0 +1,33 @@ +use crate::interpreter::ast::{parser::Parser, WhereStackElement, WhereStackOperators, Parentheses, LogicalOperator}; +use crate::interpreter::tokenizer::token::TokenTypes; +use crate::interpreter::ast::helpers::where_clause::where_condition::get_condition; + + +pub fn get_where_stack_element(parser: &mut Parser, operator_stack: &Vec) -> Result, String> { + let token_type = &parser.current_token()?.token_type; + match token_type { + TokenTypes::And | TokenTypes::Or | TokenTypes::Not | TokenTypes::LeftParen | TokenTypes::RightParen => { + if token_type == &TokenTypes::RightParen && !operator_stack.contains(&WhereStackOperators::Parentheses(Parentheses::Left)) { + return Ok(None); // Mismatched parens can be caused by the UNION STATEMENTs. 
+ } + let where_stack_element = token_type_to_where_stack_element(token_type); + parser.advance()?; + Ok(Some(where_stack_element)) + }, + TokenTypes::Identifier | TokenTypes::IntLiteral | TokenTypes::RealLiteral | TokenTypes::String | TokenTypes::Blob | TokenTypes::Null => { + return Ok(Some(WhereStackElement::Condition(get_condition(parser)?))); + } + _ => return Ok(None), + } +} + +fn token_type_to_where_stack_element(token_type: &TokenTypes) -> WhereStackElement { + match token_type { + TokenTypes::And => WhereStackElement::LogicalOperator(LogicalOperator::And), + TokenTypes::Or => WhereStackElement::LogicalOperator(LogicalOperator::Or), + TokenTypes::Not => WhereStackElement::LogicalOperator(LogicalOperator::Not), + TokenTypes::LeftParen => WhereStackElement::Parentheses(Parentheses::Left), + TokenTypes::RightParen => WhereStackElement::Parentheses(Parentheses::Right), + _ => unreachable!("Invalid token type for where stack element"), + } +} \ No newline at end of file diff --git a/src/interpreter/ast/mod.rs b/src/interpreter/ast/mod.rs index 57cf16d..27693fc 100644 --- a/src/interpreter/ast/mod.rs +++ b/src/interpreter/ast/mod.rs @@ -200,11 +200,21 @@ pub enum WhereStackElement { Parentheses(Parentheses), } +#[derive(Debug, PartialEq)] pub enum WhereStackOperators { LogicalOperator(LogicalOperator), Parentheses(Parentheses), } +impl WhereStackOperators { + pub fn into_where_stack_element(self) -> WhereStackElement { + match self { + WhereStackOperators::LogicalOperator(logical_operator) => WhereStackElement::LogicalOperator(logical_operator), + WhereStackOperators::Parentheses(parentheses) => WhereStackElement::Parentheses(parentheses), + } + } +} + #[derive(Debug, PartialEq)] pub enum LogicalOperator { Not, diff --git a/src/interpreter/ast/update_statement.rs b/src/interpreter/ast/update_statement.rs index 1f1b3e6..157b925 100644 --- a/src/interpreter/ast/update_statement.rs +++ b/src/interpreter/ast/update_statement.rs @@ -3,7 +3,7 @@ use 
crate::interpreter::ast::{ helpers::common::{expect_token_type, token_to_value, get_table_name}, helpers::{order_by_clause::get_order_by, limit_clause::get_limit}, }; -use crate::interpreter::ast::helpers::where_stack::get_where_clause; +use crate::interpreter::ast::helpers::where_clause::get_where_clause; use crate::interpreter::tokenizer::token::TokenTypes; pub fn build(parser: &mut Parser) -> Result {