diff --git a/src/cli/ast/common.rs b/src/cli/ast/common.rs index 8422c3e..bbc36e8 100644 --- a/src/cli/ast/common.rs +++ b/src/cli/ast/common.rs @@ -1,4 +1,4 @@ -use crate::cli::{ast::{parser::Parser, WhereClause, Operator}, tokenizer::token::TokenTypes}; +use crate::cli::{ast::{parser::Parser, WhereClause, Operator, OrderByClause, OrderByDirection, LimitClause}, tokenizer::token::TokenTypes}; use crate::db::table::Value; use hex::decode; @@ -55,6 +55,14 @@ pub fn tokens_to_identifier_list(parser: &mut Parser) -> Result, Str return Ok(identifiers); } +pub fn get_table_name(parser: &mut Parser) -> Result { + parser.advance()?; + let token = parser.current_token()?; + expect_token_type(parser, TokenTypes::Identifier)?; + let result = token.value.to_string(); + Ok(result) +} + pub fn get_where_clause(parser: &mut Parser) -> Result, String> { if expect_token_type(parser, TokenTypes::Where).is_err() { return Ok(None); @@ -86,4 +94,82 @@ pub fn get_where_clause(parser: &mut Parser) -> Result, Stri operator: operator, value: value, })); +} + + +pub fn get_order_by(parser: &mut Parser) -> Result>, String> { + if expect_token_type(parser, TokenTypes::Order).is_err() { + return Ok(None); + } + parser.advance()?; + + expect_token_type(parser, TokenTypes::By)?; + parser.advance()?; + + let mut order_by_clauses = vec![]; + loop { + let token = parser.current_token()?; + expect_token_type(parser, TokenTypes::Identifier)?; + let column = token.value.to_string(); + parser.advance()?; + + let token = parser.current_token()?; + let direction = match token.token_type { + TokenTypes::Asc => { + parser.advance()?; + OrderByDirection::Asc + }, + TokenTypes::Desc => { + parser.advance()?; + OrderByDirection::Desc + }, + _ => OrderByDirection::Asc, + }; + + order_by_clauses.push(OrderByClause { + column: column, + direction: direction, + }); + + let token = parser.current_token()?; + if token.token_type != TokenTypes::Comma { + break; + } + parser.advance()?; + } + return Ok(Some(order_by_clauses)); +} + +pub fn get_limit(parser: &mut Parser) -> Result, String> { + if expect_token_type(parser, TokenTypes::Limit).is_err() { + return Ok(None); + } + parser.advance()?; + + expect_token_type(parser, TokenTypes::IntLiteral)?; + let limit = token_to_value(parser)?; + parser.advance()?; + + let token = parser.current_token()?; + if token.token_type != TokenTypes::Offset { + return Ok(Some(LimitClause { + limit: limit, + offset: None, + })); + } + parser.advance()?; + + expect_token_type(parser, TokenTypes::IntLiteral)?; + let offset = token_to_value(parser)?; + if let Value::Integer(offset) = offset { + if offset < 0 { + return Err(parser.format_error()); + } + }; + parser.advance()?; + + return Ok(Some(LimitClause { + limit: limit, + offset: Some(offset), + })); } \ No newline at end of file diff --git a/src/cli/ast/delete_statement.rs b/src/cli/ast/delete_statement.rs new file mode 100644 index 0000000..e4f3d82 --- /dev/null +++ b/src/cli/ast/delete_statement.rs @@ -0,0 +1,106 @@ +use crate::cli::ast::{parser::Parser, SqlStatement, DeleteStatement, common::{expect_token_type, get_table_name, get_where_clause, get_order_by, get_limit}}; +use crate::cli::tokenizer::token::TokenTypes; + +pub fn build(parser: &mut Parser) -> Result { + parser.advance()?; + expect_token_type(parser, TokenTypes::From)?; + let table_name = get_table_name(parser)?; + parser.advance()?; + let where_clause = get_where_clause(parser)?; + let order_by_clause = get_order_by(parser)?; + let limit_clause = get_limit(parser)?; + + return Ok(SqlStatement::DeleteStatement(DeleteStatement { + table_name: table_name, + where_clause: where_clause, + order_by_clause: order_by_clause, + limit_clause: limit_clause, + })); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cli::tokenizer::scanner::Token; + use crate::cli::ast::OrderByClause; + use crate::cli::ast::OrderByDirection; + use crate::cli::ast::LimitClause; + use crate::cli::ast::Operator; + use crate::cli::ast::WhereClause; + use crate::db::table::Value; + + fn token(tt: TokenTypes, val: &'static str) -> Token<'static> { + Token { + token_type: tt, + value: val, + col_num: 0, + line_num: 1, + } + } + + #[test] + fn delete_statement_with_all_tokens_is_generated_correctly() { + // DELETE FROM users; + let tokens = vec![ + token(TokenTypes::Delete, "DELETE"), + token(TokenTypes::From, "FROM"), + token(TokenTypes::Identifier, "users"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = build(&mut parser); + assert!(result.is_ok()); + let statement = result.unwrap(); + let expected = SqlStatement::DeleteStatement(DeleteStatement { + table_name: "users".to_string(), + where_clause: None, + order_by_clause: None, + limit_clause: None, + }); + assert_eq!(expected, statement); + } + + #[test] + fn delete_statement_with_all_clauses_is_generated_correctly() { + // DELETE FROM users WHERE id = 1 ORDER BY id ASC LIMIT 10 OFFSET 5; + let tokens = vec![ + token(TokenTypes::Delete, "DELETE"), + token(TokenTypes::From, "FROM"), + token(TokenTypes::Identifier, "users"), + token(TokenTypes::Where, "WHERE"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Equals, "="), + token(TokenTypes::IntLiteral, "1"), + token(TokenTypes::Order, "ORDER"), + token(TokenTypes::By, "BY"), + token(TokenTypes::Identifier, "id"), + token(TokenTypes::Asc, "ASC"), + token(TokenTypes::Limit, "LIMIT"), + token(TokenTypes::IntLiteral, "10"), + token(TokenTypes::Offset, "OFFSET"), + token(TokenTypes::IntLiteral, "5"), + token(TokenTypes::SemiColon, ";"), + ]; + let mut parser = Parser::new(tokens); + let result = build(&mut parser); + assert!(result.is_ok()); + let statement = result.unwrap(); + let expected = SqlStatement::DeleteStatement(DeleteStatement { + table_name: "users".to_string(), + where_clause: Some(WhereClause { + column: "id".to_string(), + operator: Operator::Equals, + value: Value::Integer(1), + }), + order_by_clause: Some(vec![OrderByClause { + column: "id".to_string(), + direction: OrderByDirection::Asc, + }]), + limit_clause: Some(LimitClause { + limit: Value::Integer(10), + offset: Some(Value::Integer(5)), + }), + }); + assert_eq!(expected, statement); + } +} \ No newline at end of file diff --git a/src/cli/ast/mod.rs b/src/cli/ast/mod.rs index dfa2f96..25708b8 100644 --- a/src/cli/ast/mod.rs +++ b/src/cli/ast/mod.rs @@ -7,6 +7,7 @@ mod insert_statement; mod parser; mod select_statement; mod update_statement; +mod delete_statement; #[derive(Debug, PartialEq)] pub enum SqlStatement { @@ -14,6 +15,7 @@ pub enum SqlStatement { InsertInto(InsertIntoStatement), Select(SelectStatement), UpdateStatement(UpdateStatement), + DeleteStatement(DeleteStatement), } #[derive(Debug, PartialEq)] @@ -38,6 +40,14 @@ pub struct SelectStatement { pub limit_clause: Option, } +#[derive(Debug, PartialEq)] +pub struct DeleteStatement { + pub table_name: String, + pub where_clause: Option, + pub order_by_clause: Option>, + pub limit_clause: Option, +} + #[derive(Debug, PartialEq)] pub struct UpdateStatement { pub table_name: String, @@ -106,6 +116,7 @@ pub trait StatementBuilder { fn build_insert(&self, parser: &mut parser::Parser) -> Result; fn build_select(&self, parser: &mut parser::Parser) -> Result; fn build_update(&self, parser: &mut parser::Parser) -> Result; + fn build_delete(&self, parser: &mut parser::Parser) -> Result; } pub struct DefaultStatementBuilder; @@ -126,6 +137,10 @@ impl StatementBuilder for DefaultStatementBuilder { fn build_update(&self, parser: &mut parser::Parser) -> Result { update_statement::build(parser) } + + fn build_delete(&self, parser: &mut parser::Parser) -> Result { + delete_statement::build(parser) + } } pub fn generate(tokens: Vec) -> Vec> { diff --git a/src/cli/ast/parser.rs b/src/cli/ast/parser.rs index 70fdb71..f414d4b 100644 --- a/src/cli/ast/parser.rs +++ b/src/cli/ast/parser.rs @@ -64,6 +64,7 @@ impl<'a> Parser<'a> { TokenTypes::Insert => Some(builder.build_insert(self)), TokenTypes::Select => Some(builder.build_select(self)), TokenTypes::Update => Some(builder.build_update(self)), + TokenTypes::Delete => Some(builder.build_delete(self)), TokenTypes::EOF => None, _ => { Some(Err(self.format_error())) @@ -142,6 +143,10 @@ mod tests { fn build_update(&self, _parser: &mut Parser) -> Result { todo!(); } + + fn build_delete(&self, _parser: &mut Parser) -> Result { + todo!(); + } } #[test] diff --git a/src/cli/ast/select_statement.rs b/src/cli/ast/select_statement.rs index 26bf286..7e36450 100644 --- a/src/cli/ast/select_statement.rs +++ b/src/cli/ast/select_statement.rs @@ -1,9 +1,10 @@ -use crate::{cli::{ast::{common::{expect_token_type, get_where_clause, token_to_value, tokens_to_identifier_list}, parser::Parser, LimitClause, OrderByClause, OrderByDirection, SelectStatement, SelectStatementColumns, SqlStatement, WhereClause}, tokenizer::token::TokenTypes}, db::table::Value}; +use crate::{cli::{ast::{common::{expect_token_type, get_where_clause, tokens_to_identifier_list, get_order_by, get_limit, get_table_name}, parser::Parser, SelectStatement, SelectStatementColumns, SqlStatement, WhereClause}, tokenizer::token::TokenTypes}}; pub fn build(parser: &mut Parser) -> Result { parser.advance()?; let columns = get_columns(parser)?; let table_name = get_table_name(parser)?; + parser.advance()?; let where_clause: Option = get_where_clause(parser)?; let order_by_clause = get_order_by(parser)?; let limit_clause = get_limit(parser)?; @@ -31,101 +32,15 @@ fn get_columns(parser: &mut Parser) -> Result { } } -fn get_table_name(parser: &mut Parser) -> Result { - parser.advance()?; - let token = parser.current_token()?; - expect_token_type(parser, TokenTypes::Identifier)?; - - let result = token.value.to_string(); - parser.advance()?; - Ok(result) -} - -fn get_order_by(parser: &mut Parser) -> Result>, String> { - if expect_token_type(parser, TokenTypes::Order).is_err() { - return Ok(None); - } - parser.advance()?; - - expect_token_type(parser, TokenTypes::By)?; - parser.advance()?; - - let mut order_by_clauses = vec![]; - loop { - let token = parser.current_token()?; - expect_token_type(parser, TokenTypes::Identifier)?; - let column = token.value.to_string(); - parser.advance()?; - - let token = parser.current_token()?; - let direction = match token.token_type { - TokenTypes::Asc => { - parser.advance()?; - OrderByDirection::Asc - }, - TokenTypes::Desc => { - parser.advance()?; - OrderByDirection::Desc - }, - _ => OrderByDirection::Asc, - }; - - order_by_clauses.push(OrderByClause { - column: column, - direction: direction, - }); - - let token = parser.current_token()?; - if token.token_type != TokenTypes::Comma { - break; - } - parser.advance()?; - } - return Ok(Some(order_by_clauses)); -} - -fn get_limit(parser: &mut Parser) -> Result, String> { - if expect_token_type(parser, TokenTypes::Limit).is_err() { - return Ok(None); - } - parser.advance()?; - - expect_token_type(parser, TokenTypes::IntLiteral)?; - let limit = token_to_value(parser)?; - parser.advance()?; - - let token = parser.current_token()?; - if token.token_type != TokenTypes::Offset { - return Ok(Some(LimitClause { - limit: limit, - offset: None, - })); - } - parser.advance()?; - - expect_token_type(parser, TokenTypes::IntLiteral)?; - let offset = token_to_value(parser)?; - if let Value::Integer(offset) = offset { - if offset < 0 { - return Err(parser.format_error()); - } - }; - parser.advance()?; - - return Ok(Some(LimitClause { - limit: limit, - offset: Some(offset), - })); - - -} - #[cfg(test)] mod tests { use super::*; use crate::cli::ast::Operator; use crate::cli::tokenizer::scanner::Token; use crate::db::table::Value; + use crate::cli::ast::OrderByClause; + use crate::cli::ast::OrderByDirection; + use crate::cli::ast::LimitClause; fn token(tt: TokenTypes, val: &'static str) -> Token<'static> { Token { diff --git a/src/cli/ast/update_statement.rs b/src/cli/ast/update_statement.rs index 91a14d7..64ea716 100644 --- a/src/cli/ast/update_statement.rs +++ b/src/cli/ast/update_statement.rs @@ -1,4 +1,4 @@ -use crate::cli::ast::{parser::Parser, SqlStatement, UpdateStatement, ColumnValue, common::{expect_token_type, get_where_clause, token_to_value}}; +use crate::cli::ast::{parser::Parser, SqlStatement, UpdateStatement, ColumnValue, common::{expect_token_type, get_where_clause, token_to_value, get_table_name}}; use crate::cli::tokenizer::token::TokenTypes; pub fn build(parser: &mut Parser) -> Result { @@ -19,14 +19,6 @@ pub fn build(parser: &mut Parser) -> Result { })); } -fn get_table_name(parser: &mut Parser) -> Result { - parser.advance()?; - let token = parser.current_token()?; - expect_token_type(parser, TokenTypes::Identifier)?; - let result = token.value.to_string(); - Ok(result) -} - // We do not currently support conditional updates such as "UPDATE table SET column = column * 1.1;" fn get_update_values(parser: &mut Parser) -> Result, String> { parser.advance()?; diff --git a/src/db/database.rs b/src/db/database.rs index 6f749d2..969a09c 100644 --- a/src/db/database.rs +++ b/src/db/database.rs @@ -32,6 +32,9 @@ impl Database { SqlStatement::UpdateStatement(_statement) => { todo!(); }, + SqlStatement::DeleteStatement(_statement) => { + todo!(); + }, } }