From 780c4f733e941c85a04b69becf57982c892a269e Mon Sep 17 00:00:00 2001 From: Fletcher555 Date: Wed, 3 Sep 2025 15:36:19 -0400 Subject: [PATCH 1/9] Swap out WhereClause for Where Tree structure instead --- src/cli/ast/delete_statement.rs | 18 +++++++--- src/cli/ast/helpers/where_clause.rs | 30 ++++++++++------ src/cli/ast/mod.rs | 45 +++++++++++++++++++++--- src/cli/ast/select_statement.rs | 20 +++++++---- src/cli/ast/update_statement.rs | 29 +++++++++++----- src/db/table/select/mod.rs | 54 +++++++++++++++++++++-------- src/db/table/select/where_clause.rs | 36 +++++++++---------- 7 files changed, 167 insertions(+), 65 deletions(-) diff --git a/src/cli/ast/delete_statement.rs b/src/cli/ast/delete_statement.rs index 185fde9..2ab167b 100644 --- a/src/cli/ast/delete_statement.rs +++ b/src/cli/ast/delete_statement.rs @@ -34,7 +34,10 @@ mod tests { use crate::cli::ast::OrderByDirection; use crate::cli::ast::LimitClause; use crate::cli::ast::Operator; - use crate::cli::ast::WhereClause; + use crate::cli::ast::WhereTreeNode; + use crate::cli::ast::WhereTreeElement; + use crate::cli::ast::WhereTreeEdge; + use crate::cli::ast::LogicalOperator; use crate::db::table::Value; #[test] @@ -86,10 +89,15 @@ mod tests { let statement = result.unwrap(); let expected = SqlStatement::DeleteStatement(DeleteStatement { table_name: "users".to_string(), - where_clause: Some(WhereClause { - column: "id".to_string(), - operator: Operator::Equals, - value: Value::Integer(1), + where_clause: Some(WhereTreeNode { + left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }))), + right: Box::new(None), + operator: LogicalOperator::Or, + negation: false, }), order_by_clause: Some(vec![OrderByClause { column: "id".to_string(), diff --git a/src/cli/ast/helpers/where_clause.rs b/src/cli/ast/helpers/where_clause.rs index 790e4f1..af4336a 100644 --- a/src/cli/ast/helpers/where_clause.rs +++ b/src/cli/ast/helpers/where_clause.rs @@ -1,7 +1,7 @@ -use crate::cli::ast::{parser::Parser, WhereClause, Operator, helpers::common::{expect_token_type, token_to_value}}; +use crate::cli::ast::{parser::Parser, WhereTreeNode, WhereTreeEdge, WhereTreeElement, Operator, LogicalOperator, helpers::common::{expect_token_type, token_to_value}}; use crate::cli::tokenizer::token::TokenTypes; -pub fn get_where_clause(parser: &mut Parser) -> Result, String> { +pub fn get_where_clause(parser: &mut Parser) -> Result, String> { if expect_token_type(parser, TokenTypes::Where).is_err() { return Ok(None); } @@ -27,10 +27,15 @@ pub fn get_where_clause(parser: &mut Parser) -> Result, Stri let value = token_to_value(parser)?; parser.advance()?; - return Ok(Some(WhereClause { - column: column, - operator: operator, - value: value, + return Ok(Some(WhereTreeNode { + left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + column: column, + operator: operator, + value: value, + }))), + right: Box::new(None), + operator: LogicalOperator::Or, + negation: false, })); } @@ -63,10 +68,15 @@ mod tests { let result = get_where_clause(&mut parser); assert!(result.is_ok()); let where_clause = result.unwrap(); - let expected = Some(WhereClause { - column: "id".to_string(), - operator: Operator::Equals, - value: Value::Integer(1), + let expected = Some(WhereTreeNode { + left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }))), + right: Box::new(None), + operator: LogicalOperator::Or, + negation: false, }); assert_eq!(expected, where_clause); assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::Limit); diff --git a/src/cli/ast/mod.rs b/src/cli/ast/mod.rs index 8c1d5ef..a093b41 100644 --- a/src/cli/ast/mod.rs +++ b/src/cli/ast/mod.rs @@ -37,7 +37,7 @@ pub struct InsertIntoStatement { pub struct SelectStatement { pub table_name: String, pub columns: SelectStatementColumns, - pub where_clause: Option, + pub where_clause: Option, pub order_by_clause: Option>, pub limit_clause: Option, } @@ -45,7 +45,7 @@ pub struct SelectStatement { #[derive(Debug, PartialEq)] pub struct DeleteStatement { pub table_name: String, - pub where_clause: Option, + pub where_clause: Option, pub order_by_clause: Option>, pub limit_clause: Option, } @@ -54,7 +54,7 @@ pub struct DeleteStatement { pub struct UpdateStatement { pub table_name: String, pub update_values: Vec, - pub where_clause: Option, + pub where_clause: Option, } #[derive(Debug, PartialEq)] @@ -88,13 +88,50 @@ pub enum Operator { GreaterEquals, } +// The default truthiness of a Null Node is False #[derive(Debug, PartialEq)] -pub struct WhereClause { +pub struct WhereTreeEdge { pub column: String, pub operator: Operator, pub value: Value, } +#[derive(Debug, PartialEq)] +pub struct WhereTreeNode { + pub left: Box>, + pub right: Box>, + pub operator: LogicalOperator, + pub negation: bool, +} + +#[derive(Debug, PartialEq)] +pub enum WhereTreeElement { + Edge(WhereTreeEdge), + _Node(WhereTreeNode), +} + +impl WhereTreeElement { + pub fn get_clause(&self) -> Result<&WhereTreeEdge, String> { + match self { + WhereTreeElement::Edge(edge) => Ok(edge), + _ => Err(format!("Found node when expected edge")), + } + } + + pub fn _get_node(&self) -> Result<&WhereTreeNode, String> { + match self { + WhereTreeElement::_Node(node) => Ok(node), + _ => Err(format!("Found edge when expected node")), + } + } +} + +#[derive(Debug, PartialEq)] +pub enum LogicalOperator { + _And, + Or, +} + #[derive(Debug, PartialEq)] pub enum OrderByDirection { Asc, diff --git a/src/cli/ast/select_statement.rs b/src/cli/ast/select_statement.rs index f034776..afc2086 100644 --- a/src/cli/ast/select_statement.rs +++ b/src/cli/ast/select_statement.rs @@ -1,6 +1,6 @@ use crate::{cli::{ ast::{ - parser::Parser, SelectStatement, SelectStatementColumns, SqlStatement, WhereClause, + parser::Parser, SelectStatement, SelectStatementColumns, SqlStatement, WhereTreeNode, helpers::{ common::{expect_token_type, tokens_to_identifier_list, get_table_name}, order_by_clause::get_order_by, where_clause::get_where_clause, limit_clause::get_limit @@ -14,7 +14,7 @@ pub fn build(parser: &mut Parser) -> Result { let columns = get_columns(parser)?; let table_name = get_table_name(parser)?; parser.advance()?; - let where_clause: Option = get_where_clause(parser)?; + let where_clause: Option = get_where_clause(parser)?; let order_by_clause = get_order_by(parser)?; let limit_clause = get_limit(parser)?; @@ -49,6 +49,9 @@ mod tests { use crate::cli::ast::OrderByClause; use crate::cli::ast::OrderByDirection; use crate::cli::ast::LimitClause; + use crate::cli::ast::WhereTreeElement; + use crate::cli::ast::WhereTreeEdge; + use crate::cli::ast::LogicalOperator; use crate::cli::ast::test_utils::token; #[test] @@ -163,10 +166,15 @@ mod tests { columns: SelectStatementColumns::Specific(vec![ "id".to_string(), ]), - where_clause: Some(WhereClause { - column: "id".to_string(), - operator: Operator::Equals, - value: Value::Integer(1), + where_clause: Some(WhereTreeNode { + left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }))), + right: Box::new(None), + operator: LogicalOperator::Or, + negation: false, }), order_by_clause: Some(vec![ OrderByClause { diff --git a/src/cli/ast/update_statement.rs b/src/cli/ast/update_statement.rs index 7bbdc05..3c1f48a 100644 --- a/src/cli/ast/update_statement.rs +++ b/src/cli/ast/update_statement.rs @@ -60,7 +60,10 @@ mod tests { use super::*; use crate::db::table::Value; use crate::cli::ast::Operator; - use crate::cli::ast::WhereClause; + use crate::cli::ast::WhereTreeNode; + use crate::cli::ast::WhereTreeElement; + use crate::cli::ast::WhereTreeEdge; + use crate::cli::ast::LogicalOperator; use crate::cli::ast::test_utils::token; #[test] @@ -116,10 +119,15 @@ mod tests { column: "column".to_string(), value: Value::Integer(1), }], - where_clause: Some(WhereClause { + where_clause: Some(WhereTreeNode { + left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { column: "id".to_string(), - operator: Operator::GreaterThan, - value: Value::Integer(2), + operator: Operator::GreaterThan, + value: Value::Integer(2), + }))), + right: Box::new(None), + operator: LogicalOperator::Or, + negation: false, }), }); assert_eq!(statement, expected); @@ -162,10 +170,15 @@ mod tests { value: Value::Text("False".to_string()), }, ], - where_clause: Some(WhereClause { - column: "id".to_string(), - operator: Operator::Equals, - value: Value::Integer(3), + where_clause: Some(WhereTreeNode { + left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(3), + }))), + right: Box::new(None), + operator: LogicalOperator::Or, + negation: false, }), }); assert_eq!(statement, expected); diff --git a/src/db/table/select/mod.rs b/src/db/table/select/mod.rs index e95ab92..9c1e1c6 100644 --- a/src/db/table/select/mod.rs +++ b/src/db/table/select/mod.rs @@ -23,7 +23,14 @@ pub fn select(table: &Table, statement: SelectStatement) -> Result Result>, String> { let mut rows: Vec> = vec![]; - if let Some(where_clause) = &statement.where_clause { + if let Some(where_tree_node) = &statement.where_clause { + // This will need to be updated once we have multiple conditions working properly + let where_tree_element = where_tree_node.left.as_ref(); + let where_clause = match where_tree_element { + Some(where_tree_element) => where_tree_element.get_clause()?, + _ => return Err(format!("Found nothing when expected edge")), + }; + if !table.has_column(&where_clause.column) { return Err(format!("Column {} does not exist in table {}", where_clause.column, table.name)); @@ -61,7 +68,11 @@ pub fn get_columns_from_row(table: &Table, row: &Vec, selected_columns: & mod tests { use super::*; use crate::db::table::{Table, Value, DataType, ColumnDefinition}; - use crate::cli::ast::{SelectStatementColumns, WhereClause, LimitClause, OrderByClause, OrderByDirection, Operator}; + use crate::cli::ast::{SelectStatementColumns, LimitClause, OrderByClause, OrderByDirection, Operator}; + use crate::cli::ast::WhereTreeNode; + use crate::cli::ast::WhereTreeElement; + use crate::cli::ast::WhereTreeEdge; + use crate::cli::ast::LogicalOperator; fn default_table() -> Table { Table { @@ -129,10 +140,15 @@ mod tests { let statement = SelectStatement { table_name: "users".to_string(), columns: SelectStatementColumns::All, - where_clause: Some(WhereClause { - column: "name".to_string(), - operator: Operator::Equals, - value: Value::Text("John".to_string()), + where_clause: Some(WhereTreeNode { + left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + column: "name".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }))), + right: Box::new(None), + operator: LogicalOperator::Or, + negation: false, }), order_by_clause: None, limit_clause: None, @@ -151,10 +167,15 @@ mod tests { let statement = SelectStatement { table_name: "users".to_string(), columns: SelectStatementColumns::Specific(vec!["name".to_string(), "age".to_string()]), - where_clause: Some(WhereClause { - column: "money".to_string(), - operator: Operator::Equals, - value: Value::Real(1000.0), + where_clause: Some(WhereTreeNode { + left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + column: "money".to_string(), + operator: Operator::Equals, + value: Value::Real(1000.0), + }))), + right: Box::new(None), + operator: LogicalOperator::Or, + negation: false, }), order_by_clause: None, limit_clause: None, @@ -194,10 +215,15 @@ mod tests { let statement = SelectStatement { table_name: "users".to_string(), columns: SelectStatementColumns::All, - where_clause: Some(WhereClause { - column: "column_not_included".to_string(), - operator: Operator::Equals, - value: Value::Text("John".to_string()), + where_clause: Some(WhereTreeNode { + left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + column: "column_not_included".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }))), + right: Box::new(None), + operator: LogicalOperator::Or, + negation: false, }), order_by_clause: None, limit_clause: None, diff --git a/src/db/table/select/where_clause.rs b/src/db/table/select/where_clause.rs index 2a7dfa9..b7bd1f0 100644 --- a/src/db/table/select/where_clause.rs +++ b/src/db/table/select/where_clause.rs @@ -1,7 +1,7 @@ -use crate::cli::ast::{Operator, WhereClause}; +use crate::cli::ast::{Operator, WhereTreeEdge}; use crate::db::table::{Table, Value, DataType}; -pub fn matches_where_clause(table: &Table, row: &Vec, where_clause: &WhereClause) -> bool { +pub fn matches_where_clause(table: &Table, row: &Vec, where_clause: &WhereTreeEdge) -> bool { let column_value = table.get_column_from_row(row, &where_clause.column); if column_value.get_type() != where_clause.value.get_type() { return false; @@ -58,7 +58,7 @@ mod tests { }, ]); let row = vec![Value::Integer(1)]; - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::Equals,value:Value::Integer(1)}; + let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::Equals,value:Value::Integer(1)}; assert!(matches_where_clause(&table, &row, &where_clause)); } @@ -68,7 +68,7 @@ mod tests { ColumnDefinition {name:"id".to_string(),data_type:DataType::Integer, constraints: vec![] }, ]); let row = vec![Value::Integer(2)]; - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::Equals,value:Value::Integer(1)}; + let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::Equals,value:Value::Integer(1)}; assert!(!matches_where_clause(&table, &row, &where_clause)); } @@ -82,7 +82,7 @@ mod tests { }, ]); let row = vec![Value::Integer(1)]; - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::Equals,value:Value::Text("Fletcher".to_string())}; + let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::Equals,value:Value::Text("Fletcher".to_string())}; assert!(!matches_where_clause(&table, &row, &where_clause)); } @@ -92,15 +92,15 @@ mod tests { ColumnDefinition {name:"id".to_string(),data_type:DataType::Integer, constraints: vec![] }, ]); let row = vec![Value::Integer(10)]; - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::GreaterThan,value:Value::Integer(0)}; + let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::GreaterThan,value:Value::Integer(0)}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Integer(0)}; + let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Integer(0)}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::LessThan,value:Value::Integer(20)}; + let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::LessThan,value:Value::Integer(20)}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::LessEquals,value:Value::Integer(20)}; + let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::LessEquals,value:Value::Integer(20)}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::NotEquals,value:Value::Integer(10)}; + let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::NotEquals,value:Value::Integer(10)}; assert!(!matches_where_clause(&table, &row, &where_clause)); } @@ -110,17 +110,17 @@ mod tests { ColumnDefinition {name:"name".to_string(),data_type:DataType::Text, constraints: vec![] }, ]); let row = vec![Value::Text("lop".to_string())]; - let where_clause = WhereClause {column:"name".to_string(),operator:Operator::GreaterEquals,value:Value::Text("abc".to_string())}; + let where_clause = WhereTreeEdge {column:"name".to_string(),operator:Operator::GreaterEquals,value:Value::Text("abc".to_string())}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"name".to_string(),operator:Operator::LessEquals,value:Value::Text("lop".to_string())}; + let where_clause = WhereTreeEdge {column:"name".to_string(),operator:Operator::LessEquals,value:Value::Text("lop".to_string())}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"name".to_string(),operator:Operator::GreaterThan,value:Value::Text("xyz".to_string())}; + let where_clause = WhereTreeEdge {column:"name".to_string(),operator:Operator::GreaterThan,value:Value::Text("xyz".to_string())}; assert!(!matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"name".to_string(),operator:Operator::LessThan,value:Value::Text("abc".to_string())}; + let where_clause = WhereTreeEdge {column:"name".to_string(),operator:Operator::LessThan,value:Value::Text("abc".to_string())}; assert!(!matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"name".to_string(),operator:Operator::NotEquals,value:Value::Text("abc".to_string())}; + let where_clause = WhereTreeEdge {column:"name".to_string(),operator:Operator::NotEquals,value:Value::Text("abc".to_string())}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereClause {column:"name".to_string(),operator:Operator::Equals,value:Value::Text("lop".to_string())}; + let where_clause = WhereTreeEdge {column:"name".to_string(),operator:Operator::Equals,value:Value::Text("lop".to_string())}; assert!(matches_where_clause(&table, &row, &where_clause)); } @@ -130,7 +130,7 @@ mod tests { ColumnDefinition {name:"id".to_string(),data_type:DataType::Integer, constraints: vec![] }, ]); let row = vec![Value::Null]; - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Integer(1)}; + let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Integer(1)}; assert!(!matches_where_clause(&table, &row, &where_clause)); } @@ -140,7 +140,7 @@ mod tests { ColumnDefinition {name:"id".to_string(),data_type:DataType::Blob, constraints: vec![] }, ]); let row = vec![Value::Blob(vec![1, 2, 3])]; - let where_clause = WhereClause {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Blob(vec![1, 2, 3])}; + let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Blob(vec![1, 2, 3])}; assert!(!matches_where_clause(&table, &row, &where_clause)); } } \ No newline at end of file From 55193e6c24ed0a9c6b0730c789d5d26b01438632 Mon Sep 17 00:00:00 2001 From: Fletcher555 Date: Wed, 3 Sep 2025 16:13:35 -0400 Subject: [PATCH 2/9] Implement single level tree-structure --- src/cli/ast/delete_statement.rs | 3 +- src/cli/ast/helpers/where_clause.rs | 76 ++++++++++++++++++++++++----- src/cli/ast/mod.rs | 4 +- src/cli/ast/select_statement.rs | 3 +- src/cli/ast/update_statement.rs | 5 +- src/db/table/select/mod.rs | 7 ++- 6 files changed, 74 insertions(+), 24 deletions(-) diff --git a/src/cli/ast/delete_statement.rs b/src/cli/ast/delete_statement.rs index 2ab167b..c38f204 100644 --- a/src/cli/ast/delete_statement.rs +++ b/src/cli/ast/delete_statement.rs @@ -37,7 +37,6 @@ mod tests { use crate::cli::ast::WhereTreeNode; use crate::cli::ast::WhereTreeElement; use crate::cli::ast::WhereTreeEdge; - use crate::cli::ast::LogicalOperator; use crate::db::table::Value; #[test] @@ -96,7 +95,7 @@ mod tests { value: Value::Integer(1), }))), right: Box::new(None), - operator: LogicalOperator::Or, + operator: None, negation: false, }), order_by_clause: Some(vec![OrderByClause { diff --git a/src/cli/ast/helpers/where_clause.rs b/src/cli/ast/helpers/where_clause.rs index af4336a..7bb6e6d 100644 --- a/src/cli/ast/helpers/where_clause.rs +++ b/src/cli/ast/helpers/where_clause.rs @@ -7,6 +7,29 @@ pub fn get_where_clause(parser: &mut Parser) -> Result, St } parser.advance()?; + let left_where_tree_edge = Some(WhereTreeElement::Edge(get_where_clause_edge(parser)?)); + let operator = match parser.current_token()?.token_type { + TokenTypes::And => Some(LogicalOperator::And), + TokenTypes::Or => Some(LogicalOperator::Or), + _ => None, + }; + let right_where_tree_edge = match operator.is_some() { + true => { + parser.advance()?; + Some(WhereTreeElement::Edge(get_where_clause_edge(parser)?)) + }, + false => None, + }; + + return Ok(Some(WhereTreeNode { + left: Box::new(left_where_tree_edge), + right: Box::new(right_where_tree_edge), + operator: operator, + negation: false, + })); +} + +fn get_where_clause_edge(parser: &mut Parser) -> Result { let token = parser.current_token()?; expect_token_type(parser, TokenTypes::Identifier)?; let column = token.value.to_string(); @@ -27,16 +50,11 @@ pub fn get_where_clause(parser: &mut Parser) -> Result, St let value = token_to_value(parser)?; parser.advance()?; - return Ok(Some(WhereTreeNode { - left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { - column: column, - operator: operator, - value: value, - }))), - right: Box::new(None), - operator: LogicalOperator::Or, - negation: false, - })); + return Ok(WhereTreeEdge { + column: column, + operator: operator, + value: value, + }); } #[cfg(test)] @@ -75,7 +93,7 @@ mod tests { value: Value::Integer(1), }))), right: Box::new(None), - operator: LogicalOperator::Or, + operator: None, negation: false, }); assert_eq!(expected, where_clause); @@ -95,4 +113,40 @@ mod tests { assert!(result.unwrap().is_none()); assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::Select); } + + #[test] + fn where_clause_with_two_conditions_is_generated_correctly() { + // WHERE id = 1 AND name = "John"; + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::And, "AND"), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + assert!(result.is_ok()); + let where_clause = result.unwrap(); + let expected = Some(WhereTreeNode { + left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }))), + right: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + column: "name".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }))), + operator: Some(LogicalOperator::And), + negation: false, + }); + assert_eq!(expected, where_clause); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } } diff --git a/src/cli/ast/mod.rs b/src/cli/ast/mod.rs index a093b41..56bf55d 100644 --- a/src/cli/ast/mod.rs +++ b/src/cli/ast/mod.rs @@ -100,7 +100,7 @@ pub struct WhereTreeEdge { pub struct WhereTreeNode { pub left: Box>, pub right: Box>, - pub operator: LogicalOperator, + pub operator: Option, pub negation: bool, } @@ -128,7 +128,7 @@ impl WhereTreeElement { #[derive(Debug, PartialEq)] pub enum LogicalOperator { - _And, + And, Or, } diff --git a/src/cli/ast/select_statement.rs b/src/cli/ast/select_statement.rs index afc2086..18c3095 100644 --- a/src/cli/ast/select_statement.rs +++ b/src/cli/ast/select_statement.rs @@ -51,7 +51,6 @@ mod tests { use crate::cli::ast::LimitClause; use crate::cli::ast::WhereTreeElement; use crate::cli::ast::WhereTreeEdge; - use crate::cli::ast::LogicalOperator; use crate::cli::ast::test_utils::token; #[test] @@ -173,7 +172,7 @@ mod tests { value: Value::Integer(1), }))), right: Box::new(None), - operator: LogicalOperator::Or, + operator: None, negation: false, }), order_by_clause: Some(vec![ diff --git a/src/cli/ast/update_statement.rs b/src/cli/ast/update_statement.rs index 3c1f48a..bd8059d 100644 --- a/src/cli/ast/update_statement.rs +++ b/src/cli/ast/update_statement.rs @@ -63,7 +63,6 @@ mod tests { use crate::cli::ast::WhereTreeNode; use crate::cli::ast::WhereTreeElement; use crate::cli::ast::WhereTreeEdge; - use crate::cli::ast::LogicalOperator; use crate::cli::ast::test_utils::token; #[test] @@ -126,7 +125,7 @@ mod tests { value: Value::Integer(2), }))), right: Box::new(None), - operator: LogicalOperator::Or, + operator: None, negation: false, }), }); @@ -177,7 +176,7 @@ mod tests { value: Value::Integer(3), }))), right: Box::new(None), - operator: LogicalOperator::Or, + operator: None, negation: false, }), }); diff --git a/src/db/table/select/mod.rs b/src/db/table/select/mod.rs index 9c1e1c6..2015f6b 100644 --- a/src/db/table/select/mod.rs +++ b/src/db/table/select/mod.rs @@ -72,7 +72,6 @@ mod tests { use crate::cli::ast::WhereTreeNode; use crate::cli::ast::WhereTreeElement; use crate::cli::ast::WhereTreeEdge; - use crate::cli::ast::LogicalOperator; fn default_table() -> Table { Table { @@ -147,7 +146,7 @@ mod tests { value: Value::Text("John".to_string()), }))), right: Box::new(None), - operator: LogicalOperator::Or, + operator: None, negation: false, }), order_by_clause: None, @@ -174,7 +173,7 @@ mod tests { value: Value::Real(1000.0), }))), right: Box::new(None), - operator: LogicalOperator::Or, + operator: None, negation: false, }), order_by_clause: None, @@ -222,7 +221,7 @@ mod tests { value: Value::Text("John".to_string()), }))), right: Box::new(None), - operator: LogicalOperator::Or, + operator: None, negation: false, }), order_by_clause: None, From a4c9599cee6b5ccf77ec0554ae21ab6579bae91e Mon Sep 17 00:00:00 2001 From: Fletcher555 Date: Wed, 3 Sep 2025 19:59:07 -0400 Subject: [PATCH 3/9] Refactor to use Reverse Polish notation instead --- src/cli/ast/delete_statement.rs | 26 ++++++----- src/cli/ast/helpers/where_clause.rs | 67 ++++++++++------------------- src/cli/ast/mod.rs | 43 ++++++------------ src/cli/ast/select_statement.rs | 19 ++++---- src/cli/ast/update_statement.rs | 31 ++++++------- src/db/table/select/mod.rs | 45 ++++++++----------- src/db/table/select/where_clause.rs | 36 ++++++++-------- 7 files changed, 104 insertions(+), 163 deletions(-) diff --git a/src/cli/ast/delete_statement.rs b/src/cli/ast/delete_statement.rs index c38f204..5f80b8c 100644 --- a/src/cli/ast/delete_statement.rs +++ b/src/cli/ast/delete_statement.rs @@ -34,9 +34,8 @@ mod tests { use crate::cli::ast::OrderByDirection; use crate::cli::ast::LimitClause; use crate::cli::ast::Operator; - use crate::cli::ast::WhereTreeNode; - use crate::cli::ast::WhereTreeElement; - use crate::cli::ast::WhereTreeEdge; + use crate::cli::ast::WhereStackElement; + use crate::cli::ast::WhereCondition; use crate::db::table::Value; #[test] @@ -88,20 +87,19 @@ mod tests { let statement = result.unwrap(); let expected = SqlStatement::DeleteStatement(DeleteStatement { table_name: "users".to_string(), - where_clause: Some(WhereTreeNode { - left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { column: "id".to_string(), operator: Operator::Equals, value: Value::Integer(1), - }))), - right: Box::new(None), - operator: None, - negation: false, - }), - order_by_clause: Some(vec![OrderByClause { - column: "id".to_string(), - direction: OrderByDirection::Asc, - }]), + }) + ]), + order_by_clause: Some(vec![ + OrderByClause { + column: "id".to_string(), + direction: OrderByDirection::Asc, + } + ]), limit_clause: Some(LimitClause { limit: Value::Integer(10), offset: Some(Value::Integer(5)), diff --git a/src/cli/ast/helpers/where_clause.rs b/src/cli/ast/helpers/where_clause.rs index 7bb6e6d..4658491 100644 --- a/src/cli/ast/helpers/where_clause.rs +++ b/src/cli/ast/helpers/where_clause.rs @@ -1,35 +1,18 @@ -use crate::cli::ast::{parser::Parser, WhereTreeNode, WhereTreeEdge, WhereTreeElement, Operator, LogicalOperator, helpers::common::{expect_token_type, token_to_value}}; +use crate::cli::ast::{parser::Parser, WhereStackElement, WhereCondition, Operator, helpers::common::{expect_token_type, token_to_value}}; use crate::cli::tokenizer::token::TokenTypes; -pub fn get_where_clause(parser: &mut Parser) -> Result, String> { +pub fn get_where_clause(parser: &mut Parser) -> Result>, String> { if expect_token_type(parser, TokenTypes::Where).is_err() { return Ok(None); } parser.advance()?; - let left_where_tree_edge = Some(WhereTreeElement::Edge(get_where_clause_edge(parser)?)); - let operator = match parser.current_token()?.token_type { - TokenTypes::And => Some(LogicalOperator::And), - TokenTypes::Or => Some(LogicalOperator::Or), - _ => None, - }; - let right_where_tree_edge = match operator.is_some() { - true => { - parser.advance()?; - Some(WhereTreeElement::Edge(get_where_clause_edge(parser)?)) - }, - false => None, - }; + let where_condition = get_where_clause_edge(parser)?; - return Ok(Some(WhereTreeNode { - left: Box::new(left_where_tree_edge), - right: Box::new(right_where_tree_edge), - operator: operator, - negation: false, - })); + return Ok(Some(vec![WhereStackElement::Condition(where_condition)])); } -fn get_where_clause_edge(parser: &mut Parser) -> Result { +fn get_where_clause_edge(parser: &mut Parser) -> Result { let token = parser.current_token()?; expect_token_type(parser, TokenTypes::Identifier)?; let column = token.value.to_string(); @@ -50,7 +33,7 @@ fn get_where_clause_edge(parser: &mut Parser) -> Result { let value = token_to_value(parser)?; parser.advance()?; - return Ok(WhereTreeEdge { + return Ok(WhereCondition { column: column, operator: operator, value: value, @@ -62,6 +45,7 @@ mod tests { use super::*; use crate::cli::tokenizer::scanner::Token; use crate::db::table::Value; + use crate::cli::ast::LogicalOperator; fn token(tt: TokenTypes, val: &'static str) -> Token<'static> { Token { @@ -86,16 +70,11 @@ mod tests { let result = get_where_clause(&mut parser); assert!(result.is_ok()); let where_clause = result.unwrap(); - let expected = Some(WhereTreeNode { - left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { - column: "id".to_string(), - operator: Operator::Equals, - value: Value::Integer(1), - }))), - right: Box::new(None), - operator: None, - negation: false, - }); + let expected = Some(vec![WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + })]); assert_eq!(expected, where_clause); assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::Limit); } @@ -131,22 +110,22 @@ mod tests { let mut parser = Parser::new(tokens); let result = get_where_clause(&mut parser); assert!(result.is_ok()); - let where_clause = result.unwrap(); - let expected = Some(WhereTreeNode { - left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + let _where_clause = result.unwrap(); + let _expected = Some(vec![ + WhereStackElement::Condition(WhereCondition { column: "id".to_string(), operator: Operator::Equals, value: Value::Integer(1), - }))), - right: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + }), + WhereStackElement::Condition(WhereCondition { column: "name".to_string(), operator: Operator::Equals, value: Value::Text("John".to_string()), - }))), - operator: Some(LogicalOperator::And), - negation: false, - }); - assert_eq!(expected, where_clause); - assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + }), + WhereStackElement::_LogicalOperator(LogicalOperator::_And), + ]); + // TEST IS NOT WORKING YET DUE TO A CHANGE IN THE AST + // assert_eq!(expected, where_clause); + // assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); } } diff --git a/src/cli/ast/mod.rs b/src/cli/ast/mod.rs index 56bf55d..62f69d0 100644 --- a/src/cli/ast/mod.rs +++ b/src/cli/ast/mod.rs @@ -37,7 +37,7 @@ pub struct InsertIntoStatement { pub struct SelectStatement { pub table_name: String, pub columns: SelectStatementColumns, - pub where_clause: Option, + pub where_clause: Option>, pub order_by_clause: Option>, pub limit_clause: Option, } @@ -45,7 +45,7 @@ pub struct SelectStatement { #[derive(Debug, PartialEq)] pub struct DeleteStatement { pub table_name: String, - pub where_clause: Option, + pub where_clause: Option>, pub order_by_clause: Option>, pub limit_clause: Option, } @@ -54,7 +54,7 @@ pub struct DeleteStatement { pub struct UpdateStatement { pub table_name: String, pub update_values: Vec, - pub where_clause: Option, + pub where_clause: Option>, } #[derive(Debug, PartialEq)] @@ -88,48 +88,33 @@ pub enum Operator { GreaterEquals, } -// The default truthiness of a Null Node is False #[derive(Debug, PartialEq)] -pub struct WhereTreeEdge { +pub struct WhereCondition { pub column: String, pub operator: Operator, pub value: Value, } #[derive(Debug, PartialEq)] -pub struct WhereTreeNode { - pub left: Box>, - pub right: Box>, - pub operator: Option, - pub negation: bool, +pub enum WhereStackElement { + Condition(WhereCondition), + _LogicalOperator(LogicalOperator), } -#[derive(Debug, PartialEq)] -pub enum WhereTreeElement { - Edge(WhereTreeEdge), - _Node(WhereTreeNode), -} - -impl WhereTreeElement { - pub fn get_clause(&self) -> Result<&WhereTreeEdge, String> { - match self { - WhereTreeElement::Edge(edge) => Ok(edge), - _ => Err(format!("Found node when expected edge")), - } - } - - pub fn _get_node(&self) -> Result<&WhereTreeNode, String> { +impl WhereStackElement { + pub fn get_clause(&self) -> Result<&WhereCondition, String> { match self { - WhereTreeElement::_Node(node) => Ok(node), - _ => Err(format!("Found edge when expected node")), + WhereStackElement::Condition(condition) => Ok(condition), + WhereStackElement::_LogicalOperator(_) => Err(format!("Logical operator cannot be used as a condition")), } } } #[derive(Debug, PartialEq)] pub enum LogicalOperator { - And, - Or, + _Not, + _And, + _Or, } #[derive(Debug, PartialEq)] diff --git a/src/cli/ast/select_statement.rs b/src/cli/ast/select_statement.rs index 18c3095..72fa919 100644 --- a/src/cli/ast/select_statement.rs +++ b/src/cli/ast/select_statement.rs @@ -1,6 +1,6 @@ use crate::{cli::{ ast::{ - parser::Parser, SelectStatement, SelectStatementColumns, SqlStatement, WhereTreeNode, + parser::Parser, SelectStatement, SelectStatementColumns, SqlStatement, WhereStackElement, helpers::{ common::{expect_token_type, tokens_to_identifier_list, get_table_name}, order_by_clause::get_order_by, where_clause::get_where_clause, limit_clause::get_limit @@ -14,7 +14,7 @@ pub fn build(parser: &mut Parser) -> Result { let columns = get_columns(parser)?; let table_name = get_table_name(parser)?; parser.advance()?; - let where_clause: Option = get_where_clause(parser)?; + let where_clause: Option> = get_where_clause(parser)?; let order_by_clause = get_order_by(parser)?; let limit_clause = get_limit(parser)?; @@ -49,8 +49,8 @@ mod tests { use crate::cli::ast::OrderByClause; use crate::cli::ast::OrderByDirection; use crate::cli::ast::LimitClause; - use crate::cli::ast::WhereTreeElement; - use crate::cli::ast::WhereTreeEdge; + use crate::cli::ast::WhereStackElement; + use crate::cli::ast::WhereCondition; use crate::cli::ast::test_utils::token; #[test] @@ -165,16 +165,13 @@ mod tests { columns: SelectStatementColumns::Specific(vec![ "id".to_string(), ]), - where_clause: Some(WhereTreeNode { - left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { column: "id".to_string(), operator: Operator::Equals, value: Value::Integer(1), - }))), - right: Box::new(None), - operator: None, - negation: false, - }), + }), + ]), order_by_clause: Some(vec![ OrderByClause { column: "id".to_string(), diff --git a/src/cli/ast/update_statement.rs b/src/cli/ast/update_statement.rs index bd8059d..d9bbfbf 100644 --- a/src/cli/ast/update_statement.rs +++ b/src/cli/ast/update_statement.rs @@ -60,9 +60,8 @@ mod tests { use super::*; use crate::db::table::Value; use crate::cli::ast::Operator; - use crate::cli::ast::WhereTreeNode; - use crate::cli::ast::WhereTreeElement; - use crate::cli::ast::WhereTreeEdge; + use crate::cli::ast::WhereStackElement; + use crate::cli::ast::WhereCondition; use crate::cli::ast::test_utils::token; #[test] @@ -118,16 +117,13 @@ mod tests { column: "column".to_string(), value: Value::Integer(1), }], - where_clause: Some(WhereTreeNode { - left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { - column: "id".to_string(), + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), operator: Operator::GreaterThan, value: Value::Integer(2), - }))), - right: Box::new(None), - operator: None, - negation: false, - }), + }), + ]), }); assert_eq!(statement, expected); } @@ -169,17 +165,14 @@ mod tests { value: Value::Text("False".to_string()), }, ], - where_clause: Some(WhereTreeNode { - left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { column: "id".to_string(), operator: Operator::Equals, value: Value::Integer(3), - }))), - right: Box::new(None), - operator: None, - negation: false, - }), - }); + }), + ]), + }); assert_eq!(statement, expected); } } \ No newline at end of file diff --git a/src/db/table/select/mod.rs b/src/db/table/select/mod.rs index 2015f6b..f52e80f 100644 --- a/src/db/table/select/mod.rs +++ b/src/db/table/select/mod.rs @@ -23,11 +23,10 @@ pub fn select(table: &Table, statement: SelectStatement) -> Result Result>, String> { let mut rows: Vec> = vec![]; - if let Some(where_tree_node) = &statement.where_clause { + if let Some(where_stack) = &statement.where_clause { // This will need to be updated once we have multiple conditions working properly - let where_tree_element = where_tree_node.left.as_ref(); - let where_clause = match where_tree_element { - Some(where_tree_element) => where_tree_element.get_clause()?, + let where_clause = match where_stack.first() { + Some(where_clause) => where_clause.get_clause()?, _ => return Err(format!("Found nothing when expected edge")), }; @@ -69,9 +68,8 @@ mod tests { use super::*; use crate::db::table::{Table, Value, DataType, ColumnDefinition}; use crate::cli::ast::{SelectStatementColumns, LimitClause, OrderByClause, OrderByDirection, Operator}; - use crate::cli::ast::WhereTreeNode; - use crate::cli::ast::WhereTreeElement; - use crate::cli::ast::WhereTreeEdge; + use crate::cli::ast::WhereStackElement; + use crate::cli::ast::WhereCondition; fn default_table() -> Table { Table { @@ -139,16 +137,13 @@ mod tests { let statement = SelectStatement { table_name: "users".to_string(), columns: SelectStatementColumns::All, - where_clause: Some(WhereTreeNode { - left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { column: "name".to_string(), operator: Operator::Equals, value: Value::Text("John".to_string()), - }))), - right: Box::new(None), - operator: None, - negation: false, - }), + }), + ]), order_by_clause: None, limit_clause: None, }; @@ -166,16 +161,13 @@ mod tests { let statement = SelectStatement { table_name: "users".to_string(), columns: SelectStatementColumns::Specific(vec!["name".to_string(), "age".to_string()]), - where_clause: Some(WhereTreeNode { - left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { column: "money".to_string(), operator: Operator::Equals, value: Value::Real(1000.0), - }))), - right: Box::new(None), - operator: None, - negation: false, - }), + }), + ]), order_by_clause: None, limit_clause: None, }; @@ -214,16 +206,13 @@ mod tests { let statement = SelectStatement { table_name: "users".to_string(), columns: SelectStatementColumns::All, - where_clause: Some(WhereTreeNode { - left: Box::new(Some(WhereTreeElement::Edge(WhereTreeEdge { + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { column: "column_not_included".to_string(), operator: Operator::Equals, value: Value::Text("John".to_string()), - }))), - right: Box::new(None), - operator: None, - negation: false, - }), + }), + ]), order_by_clause: None, limit_clause: None, }; diff --git a/src/db/table/select/where_clause.rs b/src/db/table/select/where_clause.rs index b7bd1f0..3bd1156 100644 --- a/src/db/table/select/where_clause.rs +++ b/src/db/table/select/where_clause.rs @@ -1,7 +1,7 @@ -use crate::cli::ast::{Operator, WhereTreeEdge}; +use crate::cli::ast::{Operator, WhereCondition}; use crate::db::table::{Table, Value, DataType}; -pub fn matches_where_clause(table: &Table, row: &Vec, where_clause: &WhereTreeEdge) -> bool { +pub fn matches_where_clause(table: &Table, row: &Vec, where_clause: &WhereCondition) -> bool { let column_value = table.get_column_from_row(row, &where_clause.column); if column_value.get_type() != where_clause.value.get_type() { return false; @@ -58,7 +58,7 @@ mod tests { }, ]); let row = vec![Value::Integer(1)]; - let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::Equals,value:Value::Integer(1)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::Equals,value:Value::Integer(1)}; assert!(matches_where_clause(&table, &row, &where_clause)); } @@ -68,7 +68,7 @@ mod tests { ColumnDefinition {name:"id".to_string(),data_type:DataType::Integer, constraints: vec![] }, ]); let row = vec![Value::Integer(2)]; - let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::Equals,value:Value::Integer(1)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::Equals,value:Value::Integer(1)}; assert!(!matches_where_clause(&table, &row, &where_clause)); } @@ -82,7 +82,7 @@ mod tests { }, ]); let row = vec![Value::Integer(1)]; - let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::Equals,value:Value::Text("Fletcher".to_string())}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::Equals,value:Value::Text("Fletcher".to_string())}; assert!(!matches_where_clause(&table, &row, &where_clause)); } @@ -92,15 +92,15 @@ mod tests { ColumnDefinition {name:"id".to_string(),data_type:DataType::Integer, constraints: vec![] }, ]); let row = vec![Value::Integer(10)]; - let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::GreaterThan,value:Value::Integer(0)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::GreaterThan,value:Value::Integer(0)}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Integer(0)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Integer(0)}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::LessThan,value:Value::Integer(20)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::LessThan,value:Value::Integer(20)}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::LessEquals,value:Value::Integer(20)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::LessEquals,value:Value::Integer(20)}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::NotEquals,value:Value::Integer(10)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::NotEquals,value:Value::Integer(10)}; assert!(!matches_where_clause(&table, &row, &where_clause)); } @@ -110,17 +110,17 @@ mod tests { ColumnDefinition {name:"name".to_string(),data_type:DataType::Text, constraints: vec![] }, ]); let row = vec![Value::Text("lop".to_string())]; - let where_clause = WhereTreeEdge {column:"name".to_string(),operator:Operator::GreaterEquals,value:Value::Text("abc".to_string())}; + let where_clause = WhereCondition {column:"name".to_string(),operator:Operator::GreaterEquals,value:Value::Text("abc".to_string())}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereTreeEdge {column:"name".to_string(),operator:Operator::LessEquals,value:Value::Text("lop".to_string())}; + let where_clause = WhereCondition {column:"name".to_string(),operator:Operator::LessEquals,value:Value::Text("lop".to_string())}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereTreeEdge {column:"name".to_string(),operator:Operator::GreaterThan,value:Value::Text("xyz".to_string())}; + let where_clause = WhereCondition {column:"name".to_string(),operator:Operator::GreaterThan,value:Value::Text("xyz".to_string())}; assert!(!matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereTreeEdge {column:"name".to_string(),operator:Operator::LessThan,value:Value::Text("abc".to_string())}; + let where_clause = WhereCondition {column:"name".to_string(),operator:Operator::LessThan,value:Value::Text("abc".to_string())}; assert!(!matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereTreeEdge {column:"name".to_string(),operator:Operator::NotEquals,value:Value::Text("abc".to_string())}; + let where_clause = WhereCondition {column:"name".to_string(),operator:Operator::NotEquals,value:Value::Text("abc".to_string())}; assert!(matches_where_clause(&table, &row, &where_clause)); - let where_clause = WhereTreeEdge {column:"name".to_string(),operator:Operator::Equals,value:Value::Text("lop".to_string())}; + let where_clause = WhereCondition {column:"name".to_string(),operator:Operator::Equals,value:Value::Text("lop".to_string())}; assert!(matches_where_clause(&table, &row, &where_clause)); } @@ -130,7 +130,7 @@ mod tests { ColumnDefinition {name:"id".to_string(),data_type:DataType::Integer, constraints: vec![] }, ]); let row = vec![Value::Null]; - let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Integer(1)}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Integer(1)}; assert!(!matches_where_clause(&table, &row, &where_clause)); } @@ -140,7 +140,7 @@ mod tests { ColumnDefinition {name:"id".to_string(),data_type:DataType::Blob, constraints: vec![] }, ]); let row = vec![Value::Blob(vec![1, 2, 3])]; - let where_clause = WhereTreeEdge {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Blob(vec![1, 2, 3])}; + let where_clause = WhereCondition {column:"id".to_string(),operator:Operator::GreaterEquals,value:Value::Blob(vec![1, 2, 3])}; assert!(!matches_where_clause(&table, &row, &where_clause)); } } \ No newline at end of file From 6b3bb9ee430fc3f700c307795990d23c5eb75edb Mon Sep 17 00:00:00 2001 From: Fletcher555 Date: Thu, 4 Sep 2025 09:12:59 -0400 Subject: [PATCH 4/9] Add description of algorithm --- src/cli/ast/helpers/where_clause.rs | 43 ++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/src/cli/ast/helpers/where_clause.rs b/src/cli/ast/helpers/where_clause.rs index 4658491..e6000aa 100644 --- a/src/cli/ast/helpers/where_clause.rs +++ b/src/cli/ast/helpers/where_clause.rs @@ -1,6 +1,29 @@ -use crate::cli::ast::{parser::Parser, WhereStackElement, WhereCondition, Operator, helpers::common::{expect_token_type, token_to_value}}; +use crate::cli::ast::{ + helpers::common::{expect_token_type, token_to_value}, + parser::Parser, + Operator, WhereCondition, WhereStackElement, +}; use crate::cli::tokenizer::token::TokenTypes; +// This returns a tree of WhereTreeElements, which can be either a WhereTreeEdge or a WhereTreeNode. +// A WhereTreeNode is meant to represent the conditions in a logical operator. +// A WhereTreeEdge is meant to represent a single condition with a column, operator, and a value. +// WhereTreeEdges are only meant to be leaves in the tree, reading the nodes via an in-order traversal +// represents the tree in the correct order of operations as we are expected to parse. + +// To build this tree, we first create a root node which is what is eventually turned into +// the WhereStack and is then returned to the caller. We use a Stack to represent the current node in the tree. +// With the root node being the first element in the stack. If we encounter a Logical Operator, the current +// node's operator is set to the logical operator, and we push a new node to the right arm of the tree with the +// condition being the left arm of the new node. Encountering a Right Paren '(' causes us to push a node +// to the right arm of the current node with the old_node being pushed into the stack and then the new node +// being the new current node. Encountering a Left Paren ')' causes us to pop the stack and set the old_node +// as the current node. This process repeats until we have read all the conditions in the where clause. +// Encountering a NOT is done by pushing an additional node with the operator being the negation operator. +// Only one arm of the negation node contains a condition which is always the left arm. + +// The WhereStack is a representation of the tree in the correct order of operations using Reverse Polish Notation. + pub fn get_where_clause(parser: &mut Parser) -> Result>, String> { if expect_token_type(parser, TokenTypes::Where).is_err() { return Ok(None); @@ -9,7 +32,7 @@ pub fn get_where_clause(parser: &mut Parser) -> Result Result { @@ -19,7 +42,7 @@ fn get_where_clause_edge(parser: &mut Parser) -> Result parser.advance()?; let token = parser.current_token()?; - let operator = match token.token_type { + let operator = match token.token_type { TokenTypes::Equals => Operator::Equals, TokenTypes::NotEquals => Operator::NotEquals, TokenTypes::LessThan => Operator::LessThan, @@ -33,19 +56,19 @@ fn get_where_clause_edge(parser: &mut Parser) -> Result let value = token_to_value(parser)?; parser.advance()?; - return Ok(WhereCondition { - column: column, - operator: operator, - value: value, - }); + Ok(WhereCondition { + column, + operator, + value, + }) } #[cfg(test)] mod tests { use super::*; + use crate::cli::ast::LogicalOperator; use crate::cli::tokenizer::scanner::Token; use crate::db::table::Value; - use crate::cli::ast::LogicalOperator; fn token(tt: TokenTypes, val: &'static str) -> Token<'static> { Token { @@ -116,7 +139,7 @@ mod tests { column: "id".to_string(), operator: Operator::Equals, value: Value::Integer(1), - }), + }), WhereStackElement::Condition(WhereCondition { column: "name".to_string(), operator: Operator::Equals, From f23bc8e023a1c16f9c08e0313388d05d4975fcb0 Mon Sep 17 00:00:00 2001 From: Fletcher555 Date: Thu, 4 Sep 2025 14:01:26 -0400 Subject: [PATCH 5/9] Update comment with new strategy --- src/cli/ast/helpers/where_clause.rs | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/src/cli/ast/helpers/where_clause.rs b/src/cli/ast/helpers/where_clause.rs index e6000aa..6c1c77c 100644 --- a/src/cli/ast/helpers/where_clause.rs +++ b/src/cli/ast/helpers/where_clause.rs @@ -5,24 +5,12 @@ use crate::cli::ast::{ }; use crate::cli::tokenizer::token::TokenTypes; -// This returns a tree of WhereTreeElements, which can be either a WhereTreeEdge or a WhereTreeNode. -// A WhereTreeNode is meant to represent the conditions in a logical operator. -// A WhereTreeEdge is meant to represent a single condition with a column, operator, and a value. -// WhereTreeEdges are only meant to be leaves in the tree, reading the nodes via an in-order traversal -// represents the tree in the correct order of operations as we are expected to parse. - -// To build this tree, we first create a root node which is what is eventually turned into -// the WhereStack and is then returned to the caller. We use a Stack to represent the current node in the tree. -// With the root node being the first element in the stack. If we encounter a Logical Operator, the current -// node's operator is set to the logical operator, and we push a new node to the right arm of the tree with the -// condition being the left arm of the new node. Encountering a Right Paren '(' causes us to push a node -// to the right arm of the current node with the old_node being pushed into the stack and then the new node -// being the new current node. Encountering a Left Paren ')' causes us to pop the stack and set the old_node -// as the current node. This process repeats until we have read all the conditions in the where clause. -// Encountering a NOT is done by pushing an additional node with the operator being the negation operator. -// Only one arm of the negation node contains a condition which is always the left arm. - -// The WhereStack is a representation of the tree in the correct order of operations using Reverse Polish Notation. +// The WhereStack is a the method that is used to store the order of operations with Reverse Polish Notation. +// This is built from the infix expression of the where clause. Using the shunting yard algorithm. Thanks Djikstra! +// Operator precedence is given as '()' > 'NOT' > 'AND' > 'OR' +// This is currently represented as stack of LogicalOperators and WhereConditions. +// WhereConditions are currently represented as 'column operator value' +// This will later be expanded to replace the WhereConditions with a generalized evaluation function. pub fn get_where_clause(parser: &mut Parser) -> Result>, String> { if expect_token_type(parser, TokenTypes::Where).is_err() { From ec822d435ed3c1689958c611063b0b34b762afd5 Mon Sep 17 00:00:00 2001 From: Fletcher555 Date: Thu, 4 Sep 2025 19:48:41 -0400 Subject: [PATCH 6/9] Update where_clause func to properly fetch items --- src/cli/ast/helpers/where_clause.rs | 93 +++++++++++++++++++++-------- src/cli/ast/mod.rs | 19 ++---- src/db/table/select/mod.rs | 5 +- 3 files changed, 75 insertions(+), 42 deletions(-) diff --git a/src/cli/ast/helpers/where_clause.rs b/src/cli/ast/helpers/where_clause.rs index 6c1c77c..72ce461 100644 --- a/src/cli/ast/helpers/where_clause.rs +++ b/src/cli/ast/helpers/where_clause.rs @@ -1,7 +1,7 @@ use crate::cli::ast::{ helpers::common::{expect_token_type, token_to_value}, parser::Parser, - Operator, WhereCondition, WhereStackElement, + Operator, WhereCondition, WhereStackElement, LogicalOperator, }; use crate::cli::tokenizer::token::TokenTypes; @@ -12,43 +12,84 @@ use crate::cli::tokenizer::token::TokenTypes; // WhereConditions are currently represented as 'column operator value' // This will later be expanded to replace the WhereConditions with a generalized evaluation function. +enum _ValidTokensforWhereCondition { + Identifier, + Equals, + NotEquals, + LessThan, + LessEquals, + GreaterThan, + GreaterEquals, + LeftParen, + RightParen, + Not, + And, + Or, +} + pub fn get_where_clause(parser: &mut Parser) -> Result>, String> { if expect_token_type(parser, TokenTypes::Where).is_err() { return Ok(None); } parser.advance()?; - let where_condition = get_where_clause_edge(parser)?; + let where_condition = get_where_condition(parser)?; - Ok(Some(vec![WhereStackElement::Condition(where_condition)])) + Ok(Some(vec![where_condition])) } -fn get_where_clause_edge(parser: &mut Parser) -> Result { +fn get_where_condition(parser: &mut Parser) -> Result { let token = parser.current_token()?; - expect_token_type(parser, TokenTypes::Identifier)?; - let column = token.value.to_string(); - parser.advance()?; + match token.token_type { + TokenTypes::And => { + parser.advance()?; + return Ok(WhereStackElement::LogicalOperator(LogicalOperator::And)) + }, + TokenTypes::Or => { + parser.advance()?; + return Ok(WhereStackElement::LogicalOperator(LogicalOperator::Or)) + }, + TokenTypes::Not => { + parser.advance()?; + return Ok(WhereStackElement::LogicalOperator(LogicalOperator::Not)) + }, + TokenTypes::LeftParen => { + parser.advance()?; + return Ok(WhereStackElement::LogicalOperator(LogicalOperator::LeftParen)) + }, + TokenTypes::RightParen => { + parser.advance()?; + return Ok(WhereStackElement::LogicalOperator(LogicalOperator::RightParen)) + }, + TokenTypes::Identifier => { + let column = token.value.to_string(); + parser.advance()?; - let token = parser.current_token()?; - let operator = match token.token_type { - TokenTypes::Equals => Operator::Equals, - TokenTypes::NotEquals => Operator::NotEquals, - TokenTypes::LessThan => Operator::LessThan, - TokenTypes::LessEquals => Operator::LessEquals, - TokenTypes::GreaterThan => Operator::GreaterThan, - TokenTypes::GreaterEquals => Operator::GreaterEquals, - _ => return Err(parser.format_error()), - }; - parser.advance()?; + let token = parser.current_token()?; + let operator = match token.token_type { + TokenTypes::Equals => Operator::Equals, + TokenTypes::NotEquals => Operator::NotEquals, + TokenTypes::LessThan => Operator::LessThan, + TokenTypes::LessEquals => Operator::LessEquals, + TokenTypes::GreaterThan => Operator::GreaterThan, + TokenTypes::GreaterEquals => Operator::GreaterEquals, + _ => return Err(parser.format_error()), + }; + parser.advance()?; - let value = token_to_value(parser)?; - parser.advance()?; + let value = token_to_value(parser)?; + parser.advance()?; - Ok(WhereCondition { - column, - operator, - value, - }) + return Ok(WhereStackElement::Condition( + WhereCondition { + column, + operator, + value, + }) + ); + } + _ => return Err(parser.format_error()), + } } #[cfg(test)] @@ -133,7 +174,7 @@ mod tests { operator: Operator::Equals, value: Value::Text("John".to_string()), }), - WhereStackElement::_LogicalOperator(LogicalOperator::_And), + WhereStackElement::LogicalOperator(LogicalOperator::And), ]); // TEST IS NOT WORKING YET DUE TO A CHANGE IN THE AST // assert_eq!(expected, where_clause); diff --git a/src/cli/ast/mod.rs b/src/cli/ast/mod.rs index 62f69d0..aef7d5e 100644 --- a/src/cli/ast/mod.rs +++ b/src/cli/ast/mod.rs @@ -98,23 +98,16 @@ pub struct WhereCondition { #[derive(Debug, PartialEq)] pub enum WhereStackElement { Condition(WhereCondition), - _LogicalOperator(LogicalOperator), -} - -impl WhereStackElement { - pub fn get_clause(&self) -> Result<&WhereCondition, String> { - match self { - WhereStackElement::Condition(condition) => Ok(condition), - WhereStackElement::_LogicalOperator(_) => Err(format!("Logical operator cannot be used as a condition")), - } - } + LogicalOperator(LogicalOperator), } #[derive(Debug, PartialEq)] pub enum LogicalOperator { - _Not, - _And, - _Or, + Not, + And, + Or, + LeftParen, + RightParen, } #[derive(Debug, PartialEq)] diff --git a/src/db/table/select/mod.rs b/src/db/table/select/mod.rs index f52e80f..31d6367 100644 --- a/src/db/table/select/mod.rs +++ b/src/db/table/select/mod.rs @@ -2,8 +2,7 @@ pub mod where_clause; pub mod limit_clause; pub mod order_by_clause; use crate::db::table::{Table, Value}; -use crate::cli::ast::SelectStatement; -use crate::cli::ast::SelectStatementColumns; +use crate::cli::ast::{SelectStatement, WhereStackElement, SelectStatementColumns}; use crate::db::table::common::validate_and_clone_row; @@ -26,7 +25,7 @@ pub fn get_initial_rows(table: &Table, statement: &SelectStatement) -> Result where_clause.get_clause()?, + Some(WhereStackElement::Condition(where_clause)) => where_clause, _ => return Err(format!("Found nothing when expected edge")), }; From 2fb32ef2a72fb29b301fe1c11c6623540ca3f663 Mon Sep 17 00:00:00 2001 From: Fletcher555 Date: Thu, 4 Sep 2025 20:21:25 -0400 Subject: [PATCH 7/9] Shunting yard algorithm potentially working --- src/cli/ast/helpers/where_clause.rs | 122 ++++++++++++++++++++-------- src/cli/ast/mod.rs | 29 ++++++- 2 files changed, 117 insertions(+), 34 deletions(-) diff --git a/src/cli/ast/helpers/where_clause.rs b/src/cli/ast/helpers/where_clause.rs index 72ce461..c0c3ee7 100644 --- a/src/cli/ast/helpers/where_clause.rs +++ b/src/cli/ast/helpers/where_clause.rs @@ -1,7 +1,7 @@ use crate::cli::ast::{ helpers::common::{expect_token_type, token_to_value}, parser::Parser, - Operator, WhereCondition, WhereStackElement, LogicalOperator, + Operator, WhereCondition, WhereStackElement, LogicalOperator, Parentheses, WhereStackOperators, }; use crate::cli::tokenizer::token::TokenTypes; @@ -12,54 +12,113 @@ use crate::cli::tokenizer::token::TokenTypes; // WhereConditions are currently represented as 'column operator value' // This will later be expanded to replace the WhereConditions with a generalized evaluation function. -enum _ValidTokensforWhereCondition { - Identifier, - Equals, - NotEquals, - LessThan, - LessEquals, - GreaterThan, - GreaterEquals, - LeftParen, - RightParen, - Not, - And, - Or, -} - pub fn get_where_clause(parser: &mut Parser) -> Result>, String> { if expect_token_type(parser, TokenTypes::Where).is_err() { return Ok(None); } parser.advance()?; + let mut where_stack: Vec = vec![]; - let where_condition = get_where_condition(parser)?; + let mut operator_stack: Vec = vec![]; - Ok(Some(vec![where_condition])) + loop { + let where_condition = get_where_condition(parser)?; + match where_condition { + Some(where_stack_element) => { + match where_stack_element { + WhereStackElement::Condition(where_condition) => { + where_stack.push(WhereStackElement::Condition(where_condition)); + }, + WhereStackElement::Parentheses(parentheses) => { + match parentheses { + Parentheses::Left => { + operator_stack.push(WhereStackOperators::Parentheses(parentheses)); + }, + Parentheses::Right => { + loop { + let current_operator = operator_stack.pop(); + if let Some (current_operator) = current_operator { + match current_operator { + WhereStackOperators::Parentheses(Parentheses::Left) => { + break; + }, + WhereStackOperators::LogicalOperator(logical_operator) => { + where_stack.push(WhereStackElement::LogicalOperator(logical_operator)); + }, + WhereStackOperators::Parentheses(Parentheses::Right) => { + return Err("Mismatched parentheses found.".to_string()); + }, + } + } + else { + return Err("Mismatched parentheses found.".to_string()); + } + } + }, + } + }, + WhereStackElement::LogicalOperator(logical_operator) => { + loop { + let current_operator = if let Some(operator) = operator_stack.pop() { + operator + } else { + operator_stack.push(WhereStackOperators::LogicalOperator(logical_operator)); + break; + }; + match current_operator { + WhereStackOperators::LogicalOperator(current_operator) => { + if logical_operator.is_greater_precedence(¤t_operator) { + operator_stack.push(WhereStackOperators::LogicalOperator(current_operator)); + operator_stack.push(WhereStackOperators::LogicalOperator(logical_operator)); + break; + } + else { + where_stack.push(WhereStackElement::LogicalOperator(current_operator)); + } + }, + _ => return Err("Mismatched parentheses found.".to_string()), + } + } + } + } + } + None => break + } + } + while let Some(operator) = operator_stack.pop() { + match operator { + WhereStackOperators::LogicalOperator(logical_operator) => { + where_stack.push(WhereStackElement::LogicalOperator(logical_operator)); + }, + _ => return Err("Mismatched parentheses found.".to_string()), + } + } + + Ok(Some(where_stack)) } -fn get_where_condition(parser: &mut Parser) -> Result { +fn get_where_condition(parser: &mut Parser) -> Result, String> { let token = parser.current_token()?; match token.token_type { TokenTypes::And => { parser.advance()?; - return Ok(WhereStackElement::LogicalOperator(LogicalOperator::And)) + return Ok(Some(WhereStackElement::LogicalOperator(LogicalOperator::And))) }, TokenTypes::Or => { parser.advance()?; - return Ok(WhereStackElement::LogicalOperator(LogicalOperator::Or)) + return Ok(Some(WhereStackElement::LogicalOperator(LogicalOperator::Or))) }, TokenTypes::Not => { parser.advance()?; - return Ok(WhereStackElement::LogicalOperator(LogicalOperator::Not)) + return Ok(Some(WhereStackElement::LogicalOperator(LogicalOperator::Not))) }, TokenTypes::LeftParen => { parser.advance()?; - return Ok(WhereStackElement::LogicalOperator(LogicalOperator::LeftParen)) + return Ok(Some(WhereStackElement::Parentheses(Parentheses::Left))) }, TokenTypes::RightParen => { parser.advance()?; - return Ok(WhereStackElement::LogicalOperator(LogicalOperator::RightParen)) + return Ok(Some(WhereStackElement::Parentheses(Parentheses::Right))) }, TokenTypes::Identifier => { let column = token.value.to_string(); @@ -80,15 +139,15 @@ fn get_where_condition(parser: &mut Parser) -> Result let value = token_to_value(parser)?; parser.advance()?; - return Ok(WhereStackElement::Condition( + return Ok(Some(WhereStackElement::Condition( WhereCondition { column, operator, value, }) - ); + )); } - _ => return Err(parser.format_error()), + _ => return Ok(None), } } @@ -162,8 +221,8 @@ mod tests { let mut parser = Parser::new(tokens); let result = get_where_clause(&mut parser); assert!(result.is_ok()); - let _where_clause = result.unwrap(); - let _expected = Some(vec![ + let where_clause = result.unwrap(); + let expected = Some(vec![ WhereStackElement::Condition(WhereCondition { column: "id".to_string(), operator: Operator::Equals, @@ -176,8 +235,7 @@ mod tests { }), WhereStackElement::LogicalOperator(LogicalOperator::And), ]); - // TEST IS NOT WORKING YET DUE TO A CHANGE IN THE AST - // assert_eq!(expected, where_clause); - // assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + assert_eq!(expected, where_clause); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); } } diff --git a/src/cli/ast/mod.rs b/src/cli/ast/mod.rs index aef7d5e..bdfa76e 100644 --- a/src/cli/ast/mod.rs +++ b/src/cli/ast/mod.rs @@ -99,6 +99,12 @@ pub struct WhereCondition { pub enum WhereStackElement { Condition(WhereCondition), LogicalOperator(LogicalOperator), + Parentheses(Parentheses), +} + +pub enum WhereStackOperators { + LogicalOperator(LogicalOperator), + Parentheses(Parentheses), } #[derive(Debug, PartialEq)] @@ -106,8 +112,27 @@ pub enum LogicalOperator { Not, And, Or, - LeftParen, - RightParen, +} + +impl LogicalOperator { + pub fn is_greater_precedence(&self, other: &LogicalOperator) -> bool { + match (self, other) { + (LogicalOperator::Not, LogicalOperator::Not) => false, + (LogicalOperator::Not, _) => true, + (LogicalOperator::And, LogicalOperator::Not) => false, + (LogicalOperator::And, LogicalOperator::And) => false, + (LogicalOperator::And, LogicalOperator::Or) => true, + (LogicalOperator::Or, LogicalOperator::Not) => false, + (LogicalOperator::Or, LogicalOperator::And) => false, + (LogicalOperator::Or, LogicalOperator::Or) => false, + } + } +} + +#[derive(Debug, PartialEq)] +pub enum Parentheses { + Left, + Right, } #[derive(Debug, PartialEq)] From 3af1018c74bdf4f00d840c54c0881037a527aa0a Mon Sep 17 00:00:00 2001 From: Fletcher555 Date: Thu, 4 Sep 2025 20:32:05 -0400 Subject: [PATCH 8/9] Add more tests with different order of operations --- src/cli/ast/helpers/where_clause.rs | 96 +++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/src/cli/ast/helpers/where_clause.rs b/src/cli/ast/helpers/where_clause.rs index c0c3ee7..2770557 100644 --- a/src/cli/ast/helpers/where_clause.rs +++ b/src/cli/ast/helpers/where_clause.rs @@ -238,4 +238,100 @@ mod tests { assert_eq!(expected, where_clause); assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); } + + #[test] + fn where_clause_with_not_logical_operators_is_generated_correctly() { + // WHERE NOT id = 1 AND name = "John" OR age > 20; + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::And, "AND"), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Identifier, "age"), + token(TokenTypes::GreaterThan, ">"), + token(TokenTypes::IntLiteral, "20"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + assert!(result.is_ok()); + let where_clause = result.unwrap(); + let expected = Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Not), + WhereStackElement::Condition(WhereCondition { + column: "name".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }), + WhereStackElement::LogicalOperator(LogicalOperator::And), + WhereStackElement::Condition(WhereCondition { + column: "age".to_string(), + operator: Operator::GreaterThan, + value: Value::Integer(20), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Or), + ]); + assert_eq!(expected, where_clause); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_parentheses_is_generated_correctly() { + // WHERE id = 1 OR NOT name = "John" AND NOT age > 20; + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::And, "AND"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::Identifier, "age"), + token(TokenTypes::GreaterThan, ">"), + token(TokenTypes::IntLiteral, "20"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + assert!(result.is_ok()); + let where_clause = result.unwrap(); + let expected = Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + WhereStackElement::Condition(WhereCondition { + column: "name".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Not), + WhereStackElement::Condition(WhereCondition { + column: "age".to_string(), + operator: Operator::GreaterThan, + value: Value::Integer(20), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Not), + WhereStackElement::LogicalOperator(LogicalOperator::And), + WhereStackElement::LogicalOperator(LogicalOperator::Or), + ]); + assert_eq!(expected, where_clause); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } } From 6d20fe7aa94bc4e5f9368bb655a8960d35ab10ee Mon Sep 17 00:00:00 2001 From: Fletcher555 Date: Thu, 4 Sep 2025 21:03:00 -0400 Subject: [PATCH 9/9] Add error tracking for invalid WHERE statements --- src/cli/ast/helpers/where_clause.rs | 257 +++++++++++++++++++++++++++- src/cli/ast/parser.rs | 12 ++ 2 files changed, 260 insertions(+), 9 deletions(-) diff --git a/src/cli/ast/helpers/where_clause.rs b/src/cli/ast/helpers/where_clause.rs index 2770557..0df3090 100644 --- a/src/cli/ast/helpers/where_clause.rs +++ b/src/cli/ast/helpers/where_clause.rs @@ -8,18 +8,25 @@ use crate::cli::tokenizer::token::TokenTypes; // The WhereStack is a the method that is used to store the order of operations with Reverse Polish Notation. // This is built from the infix expression of the where clause. Using the shunting yard algorithm. Thanks Djikstra! // Operator precedence is given as '()' > 'NOT' > 'AND' > 'OR' -// This is currently represented as stack of LogicalOperators and WhereConditions. +// This is currently represented as stack of LogicalOperators, WhereConditions. // WhereConditions are currently represented as 'column operator value' // This will later be expanded to replace the WhereConditions with a generalized evaluation function. +// We also validate the order using an enum of the current next expected token types. +#[derive(PartialEq, Debug)] +enum WhereClauseExpectedNextToken { + ConditionLeftParenNot, + LogicalOperatorRightParen, +} + pub fn get_where_clause(parser: &mut Parser) -> Result>, String> { if expect_token_type(parser, TokenTypes::Where).is_err() { return Ok(None); } parser.advance()?; let mut where_stack: Vec = vec![]; - let mut operator_stack: Vec = vec![]; + let mut expected_next_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; loop { let where_condition = get_where_condition(parser)?; @@ -27,14 +34,25 @@ pub fn get_where_clause(parser: &mut Parser) -> Result { match where_stack_element { WhereStackElement::Condition(where_condition) => { + if expected_next_token != WhereClauseExpectedNextToken::ConditionLeftParenNot { + return Err(parser.format_error_nearby()); + } + expected_next_token = WhereClauseExpectedNextToken::LogicalOperatorRightParen; where_stack.push(WhereStackElement::Condition(where_condition)); }, WhereStackElement::Parentheses(parentheses) => { match parentheses { Parentheses::Left => { + if expected_next_token != WhereClauseExpectedNextToken::ConditionLeftParenNot { + return Err(parser.format_error_nearby()); + } operator_stack.push(WhereStackOperators::Parentheses(parentheses)); }, Parentheses::Right => { + if expected_next_token != WhereClauseExpectedNextToken::LogicalOperatorRightParen { + return Err(parser.format_error_nearby()); + } + expected_next_token = WhereClauseExpectedNextToken::LogicalOperatorRightParen; loop { let current_operator = operator_stack.pop(); if let Some (current_operator) = current_operator { @@ -58,6 +76,20 @@ pub fn get_where_clause(parser: &mut Parser) -> Result { + match logical_operator { + LogicalOperator::Not => { + if expected_next_token != WhereClauseExpectedNextToken::ConditionLeftParenNot { + return Err(parser.format_error_nearby()); + } + expected_next_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; + } + _ => { + if expected_next_token != WhereClauseExpectedNextToken::LogicalOperatorRightParen { + return Err(parser.format_error_nearby()); + } + expected_next_token = WhereClauseExpectedNextToken::ConditionLeftParenNot; + } + } loop { let current_operator = if let Some(operator) = operator_stack.pop() { operator @@ -76,7 +108,11 @@ pub fn get_where_clause(parser: &mut Parser) -> Result return Err("Mismatched parentheses found.".to_string()), + _ => { + operator_stack.push(current_operator); + operator_stack.push(WhereStackOperators::LogicalOperator(logical_operator)); + break; + }, } } } @@ -220,8 +256,6 @@ mod tests { ]; let mut parser = Parser::new(tokens); let result = get_where_clause(&mut parser); - assert!(result.is_ok()); - let where_clause = result.unwrap(); let expected = Some(vec![ WhereStackElement::Condition(WhereCondition { column: "id".to_string(), @@ -235,6 +269,8 @@ mod tests { }), WhereStackElement::LogicalOperator(LogicalOperator::And), ]); + assert!(result.is_ok()); + let where_clause = result.unwrap(); assert_eq!(expected, where_clause); assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); } @@ -260,8 +296,6 @@ mod tests { ]; let mut parser = Parser::new(tokens); let result = get_where_clause(&mut parser); - assert!(result.is_ok()); - let where_clause = result.unwrap(); let expected = Some(vec![ WhereStackElement::Condition(WhereCondition { column: "id".to_string(), @@ -282,12 +316,14 @@ mod tests { }), WhereStackElement::LogicalOperator(LogicalOperator::Or), ]); + assert!(result.is_ok()); + let where_clause = result.unwrap(); assert_eq!(expected, where_clause); assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); } #[test] - fn where_clause_with_parentheses_is_generated_correctly() { + fn where_clause_with_different_precedence_is_generated_correctly() { // WHERE id = 1 OR NOT name = "John" AND NOT age > 20; let tokens = vec![ token(TokenTypes::Where, "WHERE"), @@ -308,8 +344,62 @@ mod tests { ]; let mut parser = Parser::new(tokens); let result = get_where_clause(&mut parser); + let expected = Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + WhereStackElement::Condition(WhereCondition { + column: "name".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Not), + WhereStackElement::Condition(WhereCondition { + column: "age".to_string(), + operator: Operator::GreaterThan, + value: Value::Integer(20), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Not), + WhereStackElement::LogicalOperator(LogicalOperator::And), + WhereStackElement::LogicalOperator(LogicalOperator::Or), + ]); assert!(result.is_ok()); let where_clause = result.unwrap(); + assert_eq!(expected, where_clause); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_parentheses_is_generated_correctly() { + // WHERE (id = 1 OR name = "John") AND NOT (age > 20 OR active = 0); + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::LeftParen, "("), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::RightParen, ")"), + token(TokenTypes::And, "AND"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::LeftParen, "("), + token(TokenTypes::Identifier, "age"), + token(TokenTypes::GreaterThan, ">"), + token(TokenTypes::IntLiteral, "20"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Identifier, "active"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "0"), + token(TokenTypes::RightParen, ")"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); let expected = Some(vec![ WhereStackElement::Condition(WhereCondition { column: "id".to_string(), @@ -321,17 +411,166 @@ mod tests { operator: Operator::Equals, value: Value::Text("John".to_string()), }), - WhereStackElement::LogicalOperator(LogicalOperator::Not), + WhereStackElement::LogicalOperator(LogicalOperator::Or), WhereStackElement::Condition(WhereCondition { column: "age".to_string(), operator: Operator::GreaterThan, value: Value::Integer(20), }), + WhereStackElement::Condition(WhereCondition { + column: "active".to_string(), + operator: Operator::Equals, + value: Value::Integer(0), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Or), WhereStackElement::LogicalOperator(LogicalOperator::Not), WhereStackElement::LogicalOperator(LogicalOperator::And), + ]); + println!("{:?}", result); + assert!(result.is_ok()); + let where_clause = result.unwrap(); + assert_eq!(expected, where_clause); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_nested_parentheses_and_logical_operators_is_generated_correctly() { + // WHERE (id = 1 OR NOT (name = "John" AND age > 20)); + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::LeftParen, "("), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::LeftParen, "("), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::And, "AND"), + token(TokenTypes::Identifier, "age"), + token(TokenTypes::GreaterThan, ">"), + token(TokenTypes::IntLiteral, "20"), + token(TokenTypes::RightParen, ")"), + token(TokenTypes::RightParen, ")"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + let expected = Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + WhereStackElement::Condition(WhereCondition { + column: "name".to_string(), + operator: Operator::Equals, + value: Value::Text("John".to_string()), + }), + WhereStackElement::Condition(WhereCondition { + column: "age".to_string(), + operator: Operator::GreaterThan, + value: Value::Integer(20), + }), + WhereStackElement::LogicalOperator(LogicalOperator::And), + WhereStackElement::LogicalOperator(LogicalOperator::Not), WhereStackElement::LogicalOperator(LogicalOperator::Or), ]); + assert!(result.is_ok()); + let where_clause = result.unwrap(); assert_eq!(expected, where_clause); assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); } + + #[test] + fn where_clause_with_invalid_parentheses_is_generated_correctly() { + // WHERE (id = 1 OR name = "John"; + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::LeftParen, "("), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Mismatched parentheses found."); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_invalid_right_paren_is_generated_correctly() { + // WHERE (id = 1 OR name = "John")); + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::LeftParen, "("), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Or, "OR"), + token(TokenTypes::Identifier, "name"), + token(TokenTypes::Equals, "="), + token(TokenTypes::String, "John"), + token(TokenTypes::RightParen, ")"), + token(TokenTypes::RightParen, ")"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Mismatched parentheses found."); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_valid_not_logical_operator_is_generated_correctly() { + // WHERE NOT id = 1; + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + assert!(result.is_ok()); + let where_clause = result.unwrap(); + assert_eq!(where_clause, Some(vec![ + WhereStackElement::Condition(WhereCondition { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + WhereStackElement::LogicalOperator(LogicalOperator::Not), + ])); + assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon); + } + + #[test] + fn where_clause_with_invalid_not_logical_operator_is_generated_correctly() { + // WHERE NOT AND id = 1; + let tokens = vec![ + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Not, "NOT"), + token(TokenTypes::And, "AND"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = get_where_clause(&mut parser); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Error near line 1, column 0"); + } } diff --git a/src/cli/ast/parser.rs b/src/cli/ast/parser.rs index 4ebc884..f84163c 100644 --- a/src/cli/ast/parser.rs +++ b/src/cli/ast/parser.rs @@ -55,6 +55,18 @@ impl<'a> Parser<'a> { } } + pub fn format_error_nearby(&self) -> String { + if self.current < self.tokens.len() { + let token = &self.tokens[self.current]; + return format!( + "Error near line {:?}, column {:?}", + token.line_num, token.col_num + ); + } else { + return "Error at end of input.".to_string(); + } + } + pub fn next_statement(&mut self, builder: &dyn StatementBuilder) -> Option> { match self.current_token() { Ok(token) => match token.token_type {