diff --git a/src/cli/ast/mod.rs b/src/cli/ast/mod.rs index 168a845..96e880a 100644 --- a/src/cli/ast/mod.rs +++ b/src/cli/ast/mod.rs @@ -55,6 +55,8 @@ pub struct UpdateStatement { pub table_name: String, pub update_values: Vec, pub where_clause: Option>, + pub order_by_clause: Option>, + pub limit_clause: Option, } #[derive(Debug, PartialEq)] diff --git a/src/cli/ast/update_statement.rs b/src/cli/ast/update_statement.rs index 0dbec86..e16c8fa 100644 --- a/src/cli/ast/update_statement.rs +++ b/src/cli/ast/update_statement.rs @@ -1,6 +1,7 @@ use crate::cli::ast::{ parser::Parser, SqlStatement, UpdateStatement, ColumnValue, - helpers::common::{expect_token_type, token_to_value, get_table_name} + helpers::common::{expect_token_type, token_to_value, get_table_name}, + helpers::{order_by_clause::get_order_by, limit_clause::get_limit}, }; use crate::cli::ast::helpers::where_stack::get_where_clause; use crate::cli::tokenizer::token::TokenTypes; @@ -13,6 +14,8 @@ pub fn build(parser: &mut Parser) -> Result { expect_token_type(parser, TokenTypes::Set)?; let update_values = get_update_values(parser)?; let where_clause = get_where_clause(parser)?; + let order_by_clause = get_order_by(parser)?; + let limit_clause = get_limit(parser)?; // Ensure SemiColon expect_token_type(parser, TokenTypes::SemiColon)?; @@ -20,6 +23,8 @@ pub fn build(parser: &mut Parser) -> Result { table_name: table_name, update_values: update_values, where_clause: where_clause, + order_by_clause: order_by_clause, + limit_clause: limit_clause, })); } @@ -64,6 +69,9 @@ mod tests { use crate::cli::ast::WhereCondition; use crate::cli::ast::test_utils::token; use crate::cli::ast::Operand; + use crate::cli::ast::OrderByClause; + use crate::cli::ast::OrderByDirection; + use crate::cli::ast::LimitClause; #[test] fn update_statement_with_all_tokens_is_generated_correctly() { @@ -88,6 +96,8 @@ mod tests { value: Value::Text("value".to_string()), }], where_clause: None, + order_by_clause: None, + limit_clause: None, }); assert_eq!(statement, expected); } @@ -125,6 +135,8 @@ mod tests { r_side: Operand::Value(Value::Integer(2)), }), ]), + order_by_clause: None, + limit_clause: None, }); assert_eq!(statement, expected); } @@ -173,7 +185,62 @@ mod tests { r_side: Operand::Value(Value::Integer(3)), }), ]), + order_by_clause: None, + limit_clause: None, }); assert_eq!(statement, expected); } + + #[test] + fn update_statement_with_all_clauses_is_generated_correctly() { + // UPDATE users SET column = 1 WHERE id = 1 ORDER BY id ASC LIMIT 10 OFFSET 5; + let tokens = vec![ + token(TokenTypes::Update, "UPDATE"), + token(TokenTypes::Identifier, "users"), + token(TokenTypes::Set, "SET"), + token(TokenTypes::Identifier, "column"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Order, "ORDER"), + token(TokenTypes::By, "BY"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Asc, "ASC"), + token(TokenTypes::Limit, "LIMIT"), + token(TokenTypes::IntLiteral, "10"), + token(TokenTypes::Offset, "OFFSET"), + token(TokenTypes::IntLiteral, "5"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = build(&mut parser); + assert!(result.is_ok()); + let statement = result.unwrap(); + let expected = SqlStatement::UpdateStatement(UpdateStatement { + table_name: "users".to_string(), + update_values: vec![ColumnValue { + column: "column".to_string(), + value: Value::Integer(1), + }], + where_clause: Some(vec![ + WhereStackElement::Condition(WhereCondition { + l_side: Operand::Identifier("id".to_string()), + operator: Operator::Equals, + r_side: Operand::Value(Value::Integer(1)), + }), + ]), + order_by_clause: Some(vec![OrderByClause { + column: "id".to_string(), + direction: OrderByDirection::Asc, + }]), + limit_clause: Some(LimitClause { + limit: Value::Integer(10), + offset: Some(Value::Integer(5)), + }), + }); + assert_eq!(expected, statement); + } } \ No newline at end of file