diff --git a/src/cli/ast/delete_statement.rs b/src/cli/ast/delete_statement.rs index 185fde9..5f80b8c 100644 --- a/src/cli/ast/delete_statement.rs +++ b/src/cli/ast/delete_statement.rs @@ -34,7 +34,8 @@ mod tests { use crate::cli::ast::OrderByDirection; use crate::cli::ast::LimitClause; use crate::cli::ast::Operator; - use crate::cli::ast::WhereClause; + use crate::cli::ast::WhereStackElement; + use crate::cli::ast::WhereCondition; use crate::db::table::Value; #[test] @@ -86,15 +87,19 @@ mod tests { let statement = result.unwrap(); let expected = SqlStatement::DeleteStatement(DeleteStatement { table_name: "users".to_string(), - where_clause: Some(WhereClause { - column: "id".to_string(), - operator: Operator::Equals, - value: Value::Integer(1), - }), - order_by_clause: Some(vec![OrderByClause { - column: "id".to_string(), - direction: OrderByDirection::Asc, - }]), + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }) + ]), + order_by_clause: Some(vec![ + OrderByClause { + column: "id".to_string(), + direction: OrderByDirection::Asc, + } + ]), limit_clause: Some(LimitClause { limit: Value::Integer(10), offset: Some(Value::Integer(5)), diff --git a/src/cli/ast/helpers/where_clause.rs b/src/cli/ast/helpers/where_clause.rs index 790e4f1..0df3090 100644 --- a/src/cli/ast/helpers/where_clause.rs +++ b/src/cli/ast/helpers/where_clause.rs @@ -1,42 +1,196 @@ -use crate::cli::ast::{parser::Parser, WhereClause, Operator, helpers::common::{expect_token_type, token_to_value}}; +use crate::cli::ast::{ + helpers::common::{expect_token_type, token_to_value}, + parser::Parser, + Operator, WhereCondition, WhereStackElement, LogicalOperator, Parentheses, WhereStackOperators, +}; use crate::cli::tokenizer::token::TokenTypes; -pub fn get_where_clause(parser: &mut Parser) -> Result, String> { +// The WhereStack is a the method that is used to store the order of operations with Reverse Polish Notation. +// This is built from the infix expression of the where clause. Using the shunting yard algorithm. Thanks Djikstra! +// Operator precedence is given as '()' > 'NOT' > 'AND' > 'OR' +// This is currently represented as stack of LogicalOperators, WhereConditions. +// WhereConditions are currently represented as 'column operator value' +// This will later be expanded to replace the WhereConditions with a generalized evaluation function. + +// We also validate the order using an enum of the current next expected token types. +#[derive(PartialEq, Debug)] +enum WhereClauseExpectedNextToken { + ConditionLeftParenNot, + LogicalOperatorRightParen, +} + +pub fn get_where_clause(parser: &mut Parser) -> Result>, String> { if expect_token_type(parser, TokenTypes::Where).is_err() { return Ok(None); } parser.advance()?; + let mut where_stack: Vec = vec![]; + let mut operator_stack: Vec = vec![]; + let mut expected_next_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; - let token = parser.current_token()?; - expect_token_type(parser, TokenTypes::Identifier)?; - let column = token.value.to_string(); - parser.advance()?; + loop { + let where_condition = get_where_condition(parser)?; + match where_condition { + Some(where_stack_element) => { + match where_stack_element { + WhereStackElement::Condition(where_condition) => { + if expected_next_token != WhereClauseExpectedNextToken::ConditionLeftParenNot { + return Err(parser.format_error_nearby()); + } + expected_next_token = WhereClauseExpectedNextToken::LogicalOperatorRightParen; + where_stack.push(WhereStackElement::Condition(where_condition)); + }, + WhereStackElement::Parentheses(parentheses) => { + match parentheses { + Parentheses::Left => { + if expected_next_token != WhereClauseExpectedNextToken::ConditionLeftParenNot { + return Err(parser.format_error_nearby()); + } + operator_stack.push(WhereStackOperators::Parentheses(parentheses)); + }, + Parentheses::Right => { + if expected_next_token != WhereClauseExpectedNextToken::LogicalOperatorRightParen { + return Err(parser.format_error_nearby()); + } + expected_next_token = WhereClauseExpectedNextToken::LogicalOperatorRightParen; + loop { + let current_operator = operator_stack.pop(); + if let Some (current_operator) = current_operator { + match current_operator { + WhereStackOperators::Parentheses(Parentheses::Left) => { + break; + }, + WhereStackOperators::LogicalOperator(logical_operator) => { + where_stack.push(WhereStackElement::LogicalOperator(logical_operator)); + }, + WhereStackOperators::Parentheses(Parentheses::Right) => { + return Err("Mismatched parentheses found.".to_string()); + }, + } + } + else { + return Err("Mismatched parentheses found.".to_string()); + } + } + }, + } + }, + WhereStackElement::LogicalOperator(logical_operator) => { + match logical_operator { + LogicalOperator::Not => { + if expected_next_token != WhereClauseExpectedNextToken::ConditionLeftParenNot { + return Err(parser.format_error_nearby()); + } + expected_next_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; + } + _ => { + if expected_next_token != WhereClauseExpectedNextToken::LogicalOperatorRightParen { + return Err(parser.format_error_nearby()); + } + expected_next_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; + } + } + loop { + let current_operator = if let Some(operator) = operator_stack.pop() { + operator + } else { + operator_stack.push(WhereStackOperators::LogicalOperator(logical_operator)); + break; + }; + match current_operator { + WhereStackOperators::LogicalOperator(current_operator) => { + if logical_operator.is_greater_precedence(¤t_operator) { + operator_stack.push(WhereStackOperators::LogicalOperator(current_operator)); + operator_stack.push(WhereStackOperators::LogicalOperator(logical_operator)); + break; + } + else { + where_stack.push(WhereStackElement::LogicalOperator(current_operator)); + } + }, + _ => { + operator_stack.push(current_operator); + operator_stack.push(WhereStackOperators::LogicalOperator(logical_operator)); + break; + }, + } + } + } + } + } + None => break + } + } + while let Some(operator) = operator_stack.pop() { + match operator { + WhereStackOperators::LogicalOperator(logical_operator) => { + where_stack.push(WhereStackElement::LogicalOperator(logical_operator)); + }, + _ => return Err("Mismatched parentheses found.".to_string()), + } + } + Ok(Some(where_stack)) +} + +fn get_where_condition(parser: &mut Parser) -> Result, String> { let token = parser.current_token()?; - let operator = match token.token_type { - TokenTypes::Equals => Operator::Equals, - TokenTypes::NotEquals => Operator::NotEquals, - TokenTypes::LessThan => Operator::LessThan, - TokenTypes::LessEquals => Operator::LessEquals, - TokenTypes::GreaterThan => Operator::GreaterThan, - TokenTypes::GreaterEquals => Operator::GreaterEquals, - _ => return Err(parser.format_error()), - }; - parser.advance()?; + match token.token_type { + TokenTypes::And => { + parser.advance()?; + return Ok(Some(WhereStackElement::LogicalOperator(LogicalOperator::And))) + }, + TokenTypes::Or => { + parser.advance()?; + return Ok(Some(WhereStackElement::LogicalOperator(LogicalOperator::Or))) + }, + TokenTypes::Not => { + parser.advance()?; + return Ok(Some(WhereStackElement::LogicalOperator(LogicalOperator::Not))) + }, + TokenTypes::LeftParen => { + parser.advance()?; + return Ok(Some(WhereStackElement::Parentheses(Parentheses::Left))) + }, + TokenTypes::RightParen => { + parser.advance()?; + return Ok(Some(WhereStackElement::Parentheses(Parentheses::Right))) + }, + TokenTypes::Identifier => { + let column = token.value.to_string(); + parser.advance()?; - let value = token_to_value(parser)?; - parser.advance()?; + let token = parser.current_token()?; + let operator = match token.token_type { + TokenTypes::Equals => Operator::Equals, + TokenTypes::NotEquals => Operator::NotEquals, + TokenTypes::LessThan => Operator::LessThan, + TokenTypes::LessEquals => Operator::LessEquals, + TokenTypes::GreaterThan => Operator::GreaterThan, + TokenTypes::GreaterEquals => Operator::GreaterEquals, + _ => return Err(parser.format_error()), + }; + parser.advance()?; + + let value = token_to_value(parser)?; + parser.advance()?; - return Ok(Some(WhereClause { - column: column, - operator: operator, - value: value, - })); + return Ok(Some(WhereStackElement::Condition( + WhereCondition { + column, + operator, + value, + }) + )); + } + _ => return Ok(None), + } } #[cfg(test)] mod tests { use super::*; + use crate::cli::ast::LogicalOperator; use crate::cli::tokenizer::scanner::Token; use crate::db::table::Value; @@ -63,11 +217,11 @@ mod tests { let result = get_where_clause(&mut parser); assert!(result.is_ok()); let where_clause = result.unwrap(); - let expected = Some(WhereClause { + let expected = Some(vec![WhereStackElement::Condition(WhereCondition { column: "id".to_string(), operator: Operator::Equals, value: Value::Integer(1), - }); + })]); assert_eq!(expected, where_clause); assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::Limit); } @@ -85,4 +239,338 @@ mod tests { assert!(result.unwrap().is_none()); assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::Select); } + + #[test] + fn where_clause_with_two_conditions_is_generated_correctly() { + // WHERE id = 1 AND name = "John"; + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::And, "AND"), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + let expected = Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + WhereStackElement::Condition(WhereCondition { + column: "name".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }), + WhereStackElement::LogicalOperator(LogicalOperator::And), + ]); + assert!(result.is_ok()); + let where_clause = result.unwrap(); + assert_eq!(expected, where_clause); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_not_logical_operators_is_generated_correctly() { + // WHERE NOT id = 1 AND name = "John" OR age > 20; + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::And, "AND"), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Identifier, "age"), + token(TokenTypes::GreaterThan, ">"), + token(TokenTypes::IntLiteral, "20"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + let expected = Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Not), + WhereStackElement::Condition(WhereCondition { + column: "name".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }), + WhereStackElement::LogicalOperator(LogicalOperator::And), + WhereStackElement::Condition(WhereCondition { + column: "age".to_string(), + operator: Operator::GreaterThan, + value: Value::Integer(20), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Or), + ]); + assert!(result.is_ok()); + let where_clause = result.unwrap(); + assert_eq!(expected, where_clause); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_different_precedence_is_generated_correctly() { + // WHERE id = 1 OR NOT name = "John" AND NOT age > 20; + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::And, "AND"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::Identifier, "age"), + token(TokenTypes::GreaterThan, ">"), + token(TokenTypes::IntLiteral, "20"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + let expected = Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + WhereStackElement::Condition(WhereCondition { + column: "name".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Not), + WhereStackElement::Condition(WhereCondition { + column: "age".to_string(), + operator: Operator::GreaterThan, + value: Value::Integer(20), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Not), + WhereStackElement::LogicalOperator(LogicalOperator::And), + WhereStackElement::LogicalOperator(LogicalOperator::Or), + ]); + assert!(result.is_ok()); + let where_clause = result.unwrap(); + assert_eq!(expected, where_clause); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_parentheses_is_generated_correctly() { + // WHERE (id = 1 OR name = "John") AND NOT (age > 20 OR active = 0); + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::LeftParen, "("), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::RightParen, ")"), + token(TokenTypes::And, "AND"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::LeftParen, "("), + token(TokenTypes::Identifier, "age"), + token(TokenTypes::GreaterThan, ">"), + token(TokenTypes::IntLiteral, "20"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Identifier, "active"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "0"), + token(TokenTypes::RightParen, ")"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + let expected = Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + WhereStackElement::Condition(WhereCondition { + column: "name".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Or), + WhereStackElement::Condition(WhereCondition { + column: "age".to_string(), + operator: Operator::GreaterThan, + value: Value::Integer(20), + }), + WhereStackElement::Condition(WhereCondition { + column: "active".to_string(), + operator: Operator::Equals, + value: Value::Integer(0), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Or), + WhereStackElement::LogicalOperator(LogicalOperator::Not), + WhereStackElement::LogicalOperator(LogicalOperator::And), + ]); + println!("{:?}", result); + assert!(result.is_ok()); + let where_clause = result.unwrap(); + assert_eq!(expected, where_clause); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_nested_parentheses_and_logical_operators_is_generated_correctly() { + // WHERE (id = 1 OR NOT (name = "John" AND age > 20)); + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::LeftParen, "("), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::LeftParen, "("), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::And, "AND"), + token(TokenTypes::Identifier, "age"), + token(TokenTypes::GreaterThan, ">"), + token(TokenTypes::IntLiteral, "20"), + token(TokenTypes::RightParen, ")"), + token(TokenTypes::RightParen, ")"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + let expected = Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + WhereStackElement::Condition(WhereCondition { + column: "name".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }), + WhereStackElement::Condition(WhereCondition { + column: "age".to_string(), + operator: Operator::GreaterThan, + value: Value::Integer(20), + }), + WhereStackElement::LogicalOperator(LogicalOperator::And), + WhereStackElement::LogicalOperator(LogicalOperator::Not), + WhereStackElement::LogicalOperator(LogicalOperator::Or), + ]); + assert!(result.is_ok()); + let where_clause = result.unwrap(); + assert_eq!(expected, where_clause); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_invalid_parentheses_is_generated_correctly() { + // WHERE (id = 1 OR name = "John"; + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::LeftParen, "("), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Mismatched parentheses found."); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_invalid_right_paren_is_generated_correctly() { + // WHERE (id = 1 OR name = "John")); + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::LeftParen, "("), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::RightParen, ")"), + token(TokenTypes::RightParen, ")"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Mismatched parentheses found."); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_valid_not_logical_operator_is_generated_correctly() { + // WHERE NOT id = 1; + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + assert!(result.is_ok()); + let where_clause = result.unwrap(); + assert_eq!(where_clause, Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Not), + ])); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_invalid_not_logical_operator_is_generated_correctly() { + // WHERE NOT AND id = 1; + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::And, "AND"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Error near line 1, column 0"); + } } diff --git a/src/cli/ast/mod.rs b/src/cli/ast/mod.rs index 8c1d5ef..bdfa76e 100644 --- a/src/cli/ast/mod.rs +++ b/src/cli/ast/mod.rs @@ -37,7 +37,7 @@ pub struct InsertIntoStatement { pub struct SelectStatement { pub table_name: String, pub columns: SelectStatementColumns, - pub where_clause: Option, + pub where_clause: Option>, pub order_by_clause: Option>, pub limit_clause: Option, } @@ -45,7 +45,7 @@ pub struct SelectStatement { #[derive(Debug, PartialEq)] pub struct DeleteStatement { pub table_name: String, - pub where_clause: Option, + pub where_clause: Option>, pub order_by_clause: Option>, pub limit_clause: Option, } @@ -54,7 +54,7 @@ pub struct DeleteStatement { pub struct UpdateStatement { pub table_name: String, pub update_values: Vec, - pub where_clause: Option, + pub where_clause: Option>, } #[derive(Debug, PartialEq)] @@ -89,12 +89,52 @@ pub enum Operator { } #[derive(Debug, PartialEq)] -pub struct WhereClause { +pub struct WhereCondition { pub column: String, pub operator: Operator, pub value: Value, } +#[derive(Debug, PartialEq)] +pub enum WhereStackElement { + Condition(WhereCondition), + LogicalOperator(LogicalOperator), + Parentheses(Parentheses), +} + +pub enum WhereStackOperators { + LogicalOperator(LogicalOperator), + Parentheses(Parentheses), +} + +#[derive(Debug, PartialEq)] +pub enum LogicalOperator { + Not, + And, + Or, +} + +impl LogicalOperator { + pub fn is_greater_precedence(&self, other: &LogicalOperator) -> bool { + match (self, other) { + (LogicalOperator::Not, LogicalOperator::Not) => false, + (LogicalOperator::Not, _) => true, + (LogicalOperator::And, LogicalOperator::Not) => false, + (LogicalOperator::And, LogicalOperator::And) => false, + (LogicalOperator::And, LogicalOperator::Or) => true, + (LogicalOperator::Or, LogicalOperator::Not) => false, + (LogicalOperator::Or, LogicalOperator::And) => false, + (LogicalOperator::Or, LogicalOperator::Or) => false, + } + } +} + +#[derive(Debug, PartialEq)] +pub enum Parentheses { + Left, + Right, +} + #[derive(Debug, PartialEq)] pub enum OrderByDirection { Asc, diff --git a/src/cli/ast/parser.rs b/src/cli/ast/parser.rs index 4ebc884..f84163c 100644 --- a/src/cli/ast/parser.rs +++ b/src/cli/ast/parser.rs @@ -55,6 +55,18 @@ impl<'a> Parser<'a> { } } + pub fn format_error_nearby(&self) -> String { + if self.current < self.tokens.len() { + let token = &self.tokens[self.current]; + return format!( + "Error near line {:?}, column {:?}", + token.line_num, token.col_num + ); + } else { + return "Error at end of input.".to_string(); + } + } + pub fn next_statement(&mut self, builder: &dyn StatementBuilder) -> Option> { match self.current_token() { Ok(token) => match token.token_type { diff --git a/src/cli/ast/select_statement.rs b/src/cli/ast/select_statement.rs index f034776..72fa919 100644 --- a/src/cli/ast/select_statement.rs +++ b/src/cli/ast/select_statement.rs @@ -1,6 +1,6 @@ use crate::{cli::{ ast::{ - parser::Parser, SelectStatement, SelectStatementColumns, SqlStatement, WhereClause, + parser::Parser, SelectStatement, SelectStatementColumns, SqlStatement, WhereStackElement, helpers::{ common::{expect_token_type, tokens_to_identifier_list, get_table_name}, order_by_clause::get_order_by, where_clause::get_where_clause, limit_clause::get_limit @@ -14,7 +14,7 @@ pub fn build(parser: &mut Parser) -> Result { let columns = get_columns(parser)?; let table_name = get_table_name(parser)?; parser.advance()?; - let where_clause: Option = get_where_clause(parser)?; + let where_clause: Option> = get_where_clause(parser)?; let order_by_clause = get_order_by(parser)?; let limit_clause = get_limit(parser)?; @@ -49,6 +49,8 @@ mod tests { use crate::cli::ast::OrderByClause; use crate::cli::ast::OrderByDirection; use crate::cli::ast::LimitClause; + use crate::cli::ast::WhereStackElement; + use crate::cli::ast::WhereCondition; use crate::cli::ast::test_utils::token; #[test] @@ -163,11 +165,13 @@ mod tests { columns: SelectStatementColumns::Specific(vec![ "id".to_string(), ]), - where_clause: Some(WhereClause { - column: "id".to_string(), - operator: Operator::Equals, - value: Value::Integer(1), - }), + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + ]), order_by_clause: Some(vec![ OrderByClause { column: "id".to_string(), diff --git a/src/cli/ast/update_statement.rs b/src/cli/ast/update_statement.rs index 7bbdc05..d9bbfbf 100644 --- a/src/cli/ast/update_statement.rs +++ b/src/cli/ast/update_statement.rs @@ -60,7 +60,8 @@ mod tests { use super::*; use crate::db::table::Value; use crate::cli::ast::Operator; - use crate::cli::ast::WhereClause; + use crate::cli::ast::WhereStackElement; + use crate::cli::ast::WhereCondition; use crate::cli::ast::test_utils::token; #[test] @@ -116,11 +117,13 @@ mod tests { column: "column".to_string(), value: Value::Integer(1), }], - where_clause: Some(WhereClause { - column: "id".to_string(), - operator: Operator::GreaterThan, - value: Value::Integer(2), - }), + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::GreaterThan, + value: Value::Integer(2), + }), + ]), }); assert_eq!(statement, expected); } @@ -162,12 +165,14 @@ mod tests { value: Value::Text("False".to_string()), }, ], - where_clause: Some(WhereClause { - column: "id".to_string(), - operator: Operator::Equals, - value: Value::Integer(3), - }), - }); + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(3), + }), + ]), + }); assert_eq!(statement, expected); } } \ No newline at end of file diff --git a/src/db/table/select/mod.rs b/src/db/table/select/mod.rs index e95ab92..31d6367 100644 --- a/src/db/table/select/mod.rs +++ b/src/db/table/select/mod.rs @@ -2,8 +2,7 @@ pub mod where_clause; pub mod limit_clause; pub mod order_by_clause; use crate::db::table::{Table, Value}; -use crate::cli::ast::SelectStatement; -use crate::cli::ast::SelectStatementColumns; +use crate::cli::ast::{SelectStatement, WhereStackElement, SelectStatementColumns}; use crate::db::table::common::validate_and_clone_row; @@ -23,7 +22,13 @@ pub fn select(table: &Table, statement: SelectStatement) -> Result Result>, String> { let mut rows: Vec> = vec![]; - if let Some(where_clause) = &statement.where_clause { + if let Some(where_stack) = &statement.where_clause { + // This will need to be updated once we have multiple conditions working properly + let where_clause = match where_stack.first() { + Some(WhereStackElement::Condition(where_clause)) => where_clause, + _ => return Err(format!("Found nothing when expected edge")), + }; + if !table.has_column(&where_clause.column) { return Err(format!("Column {} does not exist in table {}", where_clause.column, table.name)); @@ -61,7 +66,9 @@ pub fn get_columns_from_row(table: &Table, row: &Vec, selected_columns: & mod tests { use super::*; use crate::db::table::{Table, Value, DataType, ColumnDefinition}; - use crate::cli::ast::{SelectStatementColumns, WhereClause, LimitClause, OrderByClause, OrderByDirection, Operator}; + use crate::cli::ast::{SelectStatementColumns, LimitClause, OrderByClause, OrderByDirection, Operator}; + use crate::cli::ast::WhereStackElement; + use crate::cli::ast::WhereCondition; fn default_table() -> Table { Table { @@ -129,11 +136,13 @@ mod tests { let statement = SelectStatement { table_name: "users".to_string(), columns: SelectStatementColumns::All, - where_clause: Some(WhereClause { - column: "name".to_string(), - operator: Operator::Equals, - value: Value::Text("John".to_string()), - }), + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "name".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }), + ]), order_by_clause: None, limit_clause: None, }; @@ -151,11 +160,13 @@ mod tests { let statement = SelectStatement { table_name: "users".to_string(), columns: SelectStatementColumns::Specific(vec!["name".to_string(), "age".to_string()]), - where_clause: Some(WhereClause { - column: "money".to_string(), - operator: Operator::Equals, - value: Value::Real(1000.0), - }), + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "money".to_string(), + operator: Operator::Equals, + value: Value::Real(1000.0), + }), + ]), order_by_clause: None, limit_clause: None, }; @@ -194,11 +205,13 @@ mod tests { let statement = SelectStatement { table_name: "users".to_string(), columns: SelectStatementColumns::All, - where_clause: Some(WhereClause { - column: "column_not_included".to_string(), - operator: Operator::Equals, - value: Value::Text("John".to_string()), - }), + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "column_not_included".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }), + ]), order_by_clause: None, limit_clause: None, }; diff --git a/src/db/table/select/where_clause.rs b/src/db/table/select/where_clause.rs index 2a7dfa9..3bd1156 100644 --- a/src/db/table/select/where_clause.rs +++ b/src/db/table/select/where_clause.rs @@ -1,7 +1,7 @@ -use crate::cli::ast::{Operator, WhereClause}; +use crate::cli::ast::{Operator, WhereCondition}; use crate::db::table::{Table, Value, DataType}; -pub fn matches_where_clause(table: &Table, row: &Vec, where_clause: &WhereClause) -> bool { +pub fn matches_where_clause(table: &Table, row: &Vec, where_clause: &WhereCondition) -> bool { let column_value = table.get_column_from_row(row, &where_clause.column); if column_value.get_type() != where_clause.value.get_type() { return false; @@ -58,7 +58,7 @@ mod tests { }, ]); let row = vec![Value::Integer(1)]; - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::Equals,value:Value::Integer(1)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::Equals,value:Value::Integer(1)}; assert!(matches_where_clause(&table, &row, &where_clause)); } @@ -68,7 +68,7 @@ mod tests { ColumnDefinition {name:"id".to_string(),data_type:DataType::Integer, constraints: vec![] }, ]); let row = vec![Value::Integer(2)]; - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::Equals,value:Value::Integer(1)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::Equals,value:Value::Integer(1)}; assert!(!matches_where_clause(&table, &row, &where_clause)); } @@ -82,7 +82,7 @@ mod tests { }, ]); let row = vec![Value::Integer(1)]; - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::Equals,value:Value::Text("Fletcher".to_string())}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::Equals,value:Value::Text("Fletcher".to_string())}; assert!(!matches_where_clause(&table, &row, &where_clause)); } @@ -92,15 +92,15 @@ mod tests { ColumnDefinition {name:"id".to_string(),data_type:DataType::Integer, constraints: vec![] }, ]); let row = vec![Value::Integer(10)]; - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::GreaterThan,value:Value::Integer(0)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::GreaterThan,value:Value::Integer(0)}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Integer(0)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Integer(0)}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::LessThan,value:Value::Integer(20)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::LessThan,value:Value::Integer(20)}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::LessEquals,value:Value::Integer(20)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::LessEquals,value:Value::Integer(20)}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::NotEquals,value:Value::Integer(10)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::NotEquals,value:Value::Integer(10)}; assert!(!matches_where_clause(&table, &row, &where_clause)); } @@ -110,17 +110,17 @@ mod tests { ColumnDefinition {name:"name".to_string(),data_type:DataType::Text, constraints: vec![] }, ]); let row = vec![Value::Text("lop".to_string())]; - let where_clause = WhereClause {column:"name".to_string(),operator:Operator::GreaterEquals,value:Value::Text("abc".to_string())}; + let where_clause = WhereCondition {column:"name".to_string(),operator:Operator::GreaterEquals,value:Value::Text("abc".to_string())}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"name".to_string(),operator:Operator::LessEquals,value:Value::Text("lop".to_string())}; + let where_clause = WhereCondition {column:"name".to_string(),operator:Operator::LessEquals,value:Value::Text("lop".to_string())}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"name".to_string(),operator:Operator::GreaterThan,value:Value::Text("xyz".to_string())}; + let where_clause = WhereCondition {column:"name".to_string(),operator:Operator::GreaterThan,value:Value::Text("xyz".to_string())}; assert!(!matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"name".to_string(),operator:Operator::LessThan,value:Value::Text("abc".to_string())}; + let where_clause = WhereCondition {column:"name".to_string(),operator:Operator::LessThan,value:Value::Text("abc".to_string())}; assert!(!matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"name".to_string(),operator:Operator::NotEquals,value:Value::Text("abc".to_string())}; + let where_clause = WhereCondition {column:"name".to_string(),operator:Operator::NotEquals,value:Value::Text("abc".to_string())}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"name".to_string(),operator:Operator::Equals,value:Value::Text("lop".to_string())}; + let where_clause = WhereCondition {column:"name".to_string(),operator:Operator::Equals,value:Value::Text("lop".to_string())}; assert!(matches_where_clause(&table, &row, &where_clause)); } @@ -130,7 +130,7 @@ mod tests { ColumnDefinition {name:"id".to_string(),data_type:DataType::Integer, constraints: vec![] }, ]); let row = vec![Value::Null]; - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Integer(1)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Integer(1)}; assert!(!matches_where_clause(&table, &row, &where_clause)); } @@ -140,7 +140,7 @@ mod tests { ColumnDefinition {name:"id".to_string(),data_type:DataType::Blob, constraints: vec![] }, ]); let row = vec![Value::Blob(vec![1, 2, 3])]; - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Blob(vec![1, 2, 3])}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Blob(vec![1, 2, 3])}; assert!(!matches_where_clause(&table, &row, &where_clause)); } } \ No newline at end of file