diff --git a/src/db/database.rs b/src/db/database.rs index 44c747a..e172570 100644 --- a/src/db/database.rs +++ b/src/db/database.rs @@ -49,6 +49,21 @@ impl Database { self.alter_table(statement)?; Ok(None) } + SqlStatement::BeginTransaction(_statement) => { + todo!() + } + SqlStatement::Commit => { + todo!() + } + SqlStatement::Rollback(_statement) => { + todo!() + } + SqlStatement::Savepoint(_statement) => { + todo!() + } + SqlStatement::Release(_statement) => { + todo!() + } } } diff --git a/src/interpreter/ast/mod.rs b/src/interpreter/ast/mod.rs index 81c3674..3abb847 100644 --- a/src/interpreter/ast/mod.rs +++ b/src/interpreter/ast/mod.rs @@ -10,6 +10,8 @@ mod delete_statement; mod helpers; mod drop_statement; mod alter_table_statement; +mod statement_builder; +mod transaction_statements; #[cfg(test)] mod test_utils; @@ -29,6 +31,11 @@ pub enum SqlStatement { DeleteStatement(DeleteStatement), DropTable(DropTableStatement), AlterTable(AlterTableStatement), + BeginTransaction(BeginStatement), + Commit, + Rollback(RollbackStatement), + Savepoint(SavepointStatement), + Release(ReleaseStatement), } #[derive(Debug, PartialEq)] @@ -136,6 +143,29 @@ pub enum AlterTableAction { DropColumn { column_name: String }, } +#[derive(Debug, PartialEq)] +pub enum BeginStatement { + Deferred, + Immediate, + Exclusive, +} + +#[derive(Debug, PartialEq)] +pub struct RollbackStatement { + pub savepoint_name: Option, +} + +#[derive(Debug, PartialEq)] +pub struct SavepointStatement { + pub savepoint_name: String, +} + +#[derive(Debug, PartialEq)] +pub struct ReleaseStatement { + pub savepoint_name: String, +} + + #[derive(Debug, PartialEq)] pub struct ColumnValue { pub column: String, @@ -264,52 +294,11 @@ pub struct LimitClause { pub offset: Option, } -pub trait StatementBuilder { - fn build_create(&self, parser: &mut parser::Parser) -> Result; - fn build_insert(&self, parser: &mut parser::Parser) -> Result; - fn build_select(&self, parser: &mut parser::Parser) -> Result; - fn build_update(&self, parser: &mut parser::Parser) -> Result; - fn build_delete(&self, parser: &mut parser::Parser) -> Result; - fn build_drop(&self, parser: &mut parser::Parser) -> Result; - fn build_alter(&self, parser: &mut parser::Parser) -> Result; -} -pub struct DefaultStatementBuilder; - -impl StatementBuilder for DefaultStatementBuilder { - fn build_create(&self, parser: &mut parser::Parser) -> Result { - create_statement::build(parser) - } - - fn build_insert(&self, parser: &mut parser::Parser) -> Result { - insert_statement::build(parser) - } - - fn build_select(&self, parser: &mut parser::Parser) -> Result { - select_statement_stack::build(parser) - } - - fn build_update(&self, parser: &mut parser::Parser) -> Result { - update_statement::build(parser) - } - - fn build_delete(&self, parser: &mut parser::Parser) -> Result { - delete_statement::build(parser) - } - - fn build_drop(&self, parser: &mut parser::Parser) -> Result { - drop_statement::build(parser) - } - - fn build_alter(&self, parser: &mut parser::Parser) -> Result { - alter_table_statement::build(parser) - } -} pub fn generate(tokens: Vec) -> Vec> { let mut results: Vec> = vec![]; let mut parser = parser::Parser::new(tokens); - let builder : &dyn StatementBuilder = &DefaultStatementBuilder; loop { let line_num = match parser.line_num() { Ok(line_num) => line_num, @@ -318,7 +307,7 @@ pub fn generate(tokens: Vec) -> Vec> break; } }; - let next_statement = parser.next_statement(builder); + let next_statement = parser.next_statement(); if let Some(next_statement) = next_statement { match next_statement { Err(error) => { diff --git a/src/interpreter/ast/parser.rs b/src/interpreter/ast/parser.rs index a2ba7d1..a4c3faa 100644 --- a/src/interpreter/ast/parser.rs +++ b/src/interpreter/ast/parser.rs @@ -1,13 +1,15 @@ use crate::interpreter::{ - ast::{SqlStatement, StatementBuilder}, - ast::helpers::token::format_statement_tokens, - tokenizer::scanner::Token, tokenizer::token::TokenTypes + ast::{helpers::token::format_statement_tokens, SqlStatement, statement_builder::{StatementBuilder, DefaultStatementBuilder}}, + tokenizer::scanner::Token, tokenizer::token::TokenTypes }; + + pub struct Parser<'a> { tokens: Vec>, start: usize, current: usize, + builder: &'a dyn StatementBuilder, } impl<'a> Parser<'a> { @@ -16,6 +18,7 @@ impl<'a> Parser<'a> { tokens, start: 0, current: 0, + builder: &DefaultStatementBuilder{}, }; } @@ -85,17 +88,22 @@ impl<'a> Parser<'a> { } } - pub fn next_statement(&mut self, builder: &dyn StatementBuilder) -> Option> { + pub fn next_statement(&mut self) -> Option> { self.start = self.current; match (&self.current_token(), &self.peek_token()) { (Ok(token), Ok(peek_token)) => match (&token.token_type, &peek_token.token_type) { - (TokenTypes::Create, _) => Some(builder.build_create(self)), - (TokenTypes::Insert, _) => Some(builder.build_insert(self)), - (TokenTypes::Select, _) | (TokenTypes::LeftParen, TokenTypes::Select) => Some(builder.build_select(self)), - (TokenTypes::Update, _) => Some(builder.build_update(self)), - (TokenTypes::Delete, _) => Some(builder.build_delete(self)), - (TokenTypes::Drop, _) => Some(builder.build_drop(self)), - (TokenTypes::Alter, _) => Some(builder.build_alter(self)), + (TokenTypes::Create, _) => Some(self.builder.build_create(self)), + (TokenTypes::Insert, _) => Some(self.builder.build_insert(self)), + (TokenTypes::Select, _) | (TokenTypes::LeftParen, TokenTypes::Select) => Some(self.builder.build_select(self)), + (TokenTypes::Update, _) => Some(self.builder.build_update(self)), + (TokenTypes::Delete, _) => Some(self.builder.build_delete(self)), + (TokenTypes::Drop, _) => Some(self.builder.build_drop(self)), + (TokenTypes::Alter, _) => Some(self.builder.build_alter(self)), + (TokenTypes::Begin, _) => Some(self.builder.build_begin(self)), + (TokenTypes::Commit, _) | (TokenTypes::End, _) => Some(self.builder.build_commit(self)), + (TokenTypes::Rollback, _) => Some(self.builder.build_rollback(self)), + (TokenTypes::Savepoint, _) => Some(self.builder.build_savepoint(self)), + (TokenTypes::Release, _) => Some(self.builder.build_release(self)), _ => { Some(Err(self.format_error())) } @@ -115,6 +123,7 @@ mod tests { use super::*; use crate::interpreter::ast::{CreateTableStatement, InsertIntoStatement, SelectStatement, SelectStatementColumns, SelectStatementStack, SelectStatementStackElement, SelectMode}; use crate::interpreter::ast::test_utils::{token_with_location, token}; + use crate::interpreter::ast::statement_builder::MockStatementBuilder; #[test] fn parser_formats_error_when_at_end_of_input() { @@ -132,64 +141,6 @@ mod tests { assert_eq!(result, "Error at line 3, column 15: Unexpected value: INSERT"); } - pub struct MockStatementBuilder; - - impl StatementBuilder for MockStatementBuilder { - fn build_create(&self, parser: &mut Parser) -> Result { - parser.advance()?; - parser.advance_past_semicolon()?; - return Ok(SqlStatement::CreateTable(CreateTableStatement { - table_name: "users".to_string(), - existence_check: None, - columns: vec![], - })); - } - - fn build_insert(&self, parser: &mut Parser) -> Result { - parser.advance()?; - parser.advance_past_semicolon()?; - return Ok(SqlStatement::InsertInto(InsertIntoStatement { - table_name: "users".to_string(), - columns: None, - values: vec![], - })); - } - - fn build_select(&self, parser: &mut Parser) -> Result { - parser.advance()?; - parser.advance_past_semicolon()?; - return Ok(SqlStatement::Select(SelectStatementStack { - columns: SelectStatementColumns::All, - elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement { - table_name: "users".to_string(), - mode: SelectMode::All, - columns: SelectStatementColumns::All, - where_clause: None, - order_by_clause: None, - limit_clause: None, - })], - order_by_clause: None, - limit_clause: None, - })); - } - - fn build_update(&self, _parser: &mut Parser) -> Result { - todo!(); - } - - fn build_delete(&self, _parser: &mut Parser) -> Result { - todo!(); - } - - fn build_drop(&self, _parser: &mut Parser) -> Result { - todo!(); - } - - fn build_alter(&self, _parser: &mut Parser) -> Result { - todo!(); - } - } - #[test] fn parser_next_statement_filters_options_correctly_handles_multiple_statements() { let tokens = vec![ @@ -201,10 +152,14 @@ mod tests { token(TokenTypes::SemiColon, ";"), token(TokenTypes::EOF, ""), ]; - let mut parser = Parser::new(tokens); - let builder : &dyn StatementBuilder = &MockStatementBuilder; + let mut parser = Parser { + tokens, + start: 0, + current: 0, + builder: &MockStatementBuilder, + }; // Create Table - let result = parser.next_statement(builder); + let result = parser.next_statement(); let expected = Some(Ok(SqlStatement::CreateTable(CreateTableStatement { table_name: "users".to_string(), existence_check: None, @@ -213,7 +168,7 @@ mod tests { assert_eq!(result, expected); // Insert Into - let result = parser.next_statement(builder); + let result = parser.next_statement(); let expected = Some(Ok(SqlStatement::InsertInto(InsertIntoStatement { table_name: "users".to_string(), columns: None, @@ -222,7 +177,7 @@ mod tests { assert_eq!(result, expected); // Select - let result = parser.next_statement(builder); + let result = parser.next_statement(); let expected = Some(Ok(SqlStatement::Select(SelectStatementStack { columns: SelectStatementColumns::All, elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement { @@ -239,7 +194,7 @@ mod tests { assert_eq!(result, expected); // EOF - let result = parser.next_statement(builder); + let result = parser.next_statement(); let expected = None; assert_eq!(result, expected); } @@ -251,9 +206,13 @@ mod tests { token(TokenTypes::SemiColon, ";"), token(TokenTypes::EOF, ""), ]; - let mut parser = Parser::new(tokens); - let builder : &dyn StatementBuilder = &MockStatementBuilder; - let result = parser.next_statement(builder); + let mut parser = Parser { + tokens, + start: 0, + current: 0, + builder: &MockStatementBuilder, + }; + let result = parser.next_statement(); let expected = Some(Err("Error at line 1, column 0: Unexpected value: users".to_string())); assert_eq!(result, expected); } diff --git a/src/interpreter/ast/statement_builder.rs b/src/interpreter/ast/statement_builder.rs new file mode 100644 index 0000000..780599f --- /dev/null +++ b/src/interpreter/ast/statement_builder.rs @@ -0,0 +1,152 @@ +use crate::interpreter::ast::{create_statement, insert_statement, select_statement_stack, update_statement, delete_statement, drop_statement, alter_table_statement, transaction_statements}; +use crate::interpreter::ast::parser::Parser; +use crate::interpreter::ast::SqlStatement; + +pub trait StatementBuilder { + fn build_create(&self, parser: &mut Parser) -> Result; + fn build_insert(&self, parser: &mut Parser) -> Result; + fn build_select(&self, parser: &mut Parser) -> Result; + fn build_update(&self, parser: &mut Parser) -> Result; + fn build_delete(&self, parser: &mut Parser) -> Result; + fn build_drop(&self, parser: &mut Parser) -> Result; + fn build_alter(&self, parser: &mut Parser) -> Result; + fn build_begin(&self, parser: &mut Parser) -> Result; + fn build_commit(&self, parser: &mut Parser) -> Result; + fn build_rollback(&self, parser: &mut Parser) -> Result; + fn build_savepoint(&self, parser: &mut Parser) -> Result; + fn build_release(&self, parser: &mut Parser) -> Result; +} + +pub struct DefaultStatementBuilder; + +impl StatementBuilder for DefaultStatementBuilder { + fn build_create(&self, parser: &mut Parser) -> Result { + create_statement::build(parser) + } + + fn build_insert(&self, parser: &mut Parser) -> Result { + insert_statement::build(parser) + } + + fn build_select(&self, parser: &mut Parser) -> Result { + select_statement_stack::build(parser) + } + + fn build_update(&self, parser: &mut Parser) -> Result { + update_statement::build(parser) + } + + fn build_delete(&self, parser: &mut Parser) -> Result { + delete_statement::build(parser) + } + + fn build_drop(&self, parser: &mut Parser) -> Result { + drop_statement::build(parser) + } + + fn build_alter(&self, parser: &mut Parser) -> Result { + alter_table_statement::build(parser) + } + + fn build_begin(&self, parser: &mut Parser) -> Result { + transaction_statements::build_begin(parser) + } + + fn build_commit(&self, parser: &mut Parser) -> Result { + transaction_statements::build_commit(parser) + } + + fn build_rollback(&self, parser: &mut Parser) -> Result { + transaction_statements::build_rollback(parser) + } + + fn build_savepoint(&self, parser: &mut Parser) -> Result { + transaction_statements::build_savepoint(parser) + } + + fn build_release(&self, parser: &mut Parser) -> Result { + transaction_statements::build_release(parser) + } +} + +#[cfg(test)] +pub struct MockStatementBuilder; +#[cfg(test)] +use crate::interpreter::ast::{CreateTableStatement, InsertIntoStatement, SelectStatementStack, SelectStatementColumns, SelectStatementStackElement, SelectStatement, SelectMode}; + +#[cfg(test)] +impl StatementBuilder for MockStatementBuilder { + fn build_create(&self, parser: &mut Parser) -> Result { + parser.advance()?; + parser.advance_past_semicolon()?; + return Ok(SqlStatement::CreateTable(CreateTableStatement { + table_name: "users".to_string(), + existence_check: None, + columns: vec![], + })); + } + + fn build_insert(&self, parser: &mut Parser) -> Result { + parser.advance()?; + parser.advance_past_semicolon()?; + return Ok(SqlStatement::InsertInto(InsertIntoStatement { + table_name: "users".to_string(), + columns: None, + values: vec![], + })); + } + + fn build_select(&self, parser: &mut Parser) -> Result { + parser.advance()?; + parser.advance_past_semicolon()?; + return Ok(SqlStatement::Select(SelectStatementStack { + columns: SelectStatementColumns::All, + elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement { + table_name: "users".to_string(), + mode: SelectMode::All, + columns: SelectStatementColumns::All, + where_clause: None, + order_by_clause: None, + limit_clause: None, + })], + order_by_clause: None, + limit_clause: None, + })); + } + + fn build_update(&self, _parser: &mut Parser) -> Result { + todo!(); + } + + fn build_delete(&self, _parser: &mut Parser) -> Result { + todo!(); + } + + fn build_drop(&self, _parser: &mut Parser) -> Result { + todo!(); + } + + fn build_alter(&self, _parser: &mut Parser) -> Result { + todo!(); + } + + fn build_begin(&self, _parser: &mut Parser) -> Result { + todo!(); + } + + fn build_commit(&self, _parser: &mut Parser) -> Result { + todo!(); + } + + fn build_rollback(&self, _parser: &mut Parser) -> Result { + todo!(); + } + + fn build_savepoint(&self, _parser: &mut Parser) -> Result { + todo!(); + } + + fn build_release(&self, _parser: &mut Parser) -> Result { + todo!(); + } +} \ No newline at end of file diff --git a/src/interpreter/ast/transaction_statements.rs b/src/interpreter/ast/transaction_statements.rs new file mode 100644 index 0000000..66699fd --- /dev/null +++ b/src/interpreter/ast/transaction_statements.rs @@ -0,0 +1,191 @@ +use crate::interpreter::ast::{parser::Parser, SqlStatement, BeginStatement, RollbackStatement, SavepointStatement, ReleaseStatement}; +use crate::interpreter::tokenizer::token::TokenTypes; +use crate::interpreter::ast::helpers::token::expect_token_type; + +pub fn build_begin(parser: &mut Parser) -> Result { + parser.advance()?; + let statement = if expect_token_type(parser, TokenTypes::Deferred).is_ok() || expect_token_type(parser, TokenTypes::SemiColon).is_ok() { + SqlStatement::BeginTransaction(BeginStatement::Deferred) + } else if expect_token_type(parser, TokenTypes::Exclusive).is_ok() { + SqlStatement::BeginTransaction(BeginStatement::Exclusive) + } + else if expect_token_type(parser, TokenTypes::Immediate).is_ok() { + SqlStatement::BeginTransaction(BeginStatement::Immediate) + } + else { + return Err(parser.format_error()); + }; + if parser.current_token()?.token_type != TokenTypes::SemiColon { + parser.advance()?; + expect_token_type(parser, TokenTypes::SemiColon)?; + } + return Ok(statement); +} + +pub fn build_commit(parser: &mut Parser) -> Result { + parser.advance()?; + expect_token_type(parser, TokenTypes::SemiColon)?; + return Ok(SqlStatement::Commit); +} + +pub fn build_rollback(parser: &mut Parser) -> Result { + parser.advance()?; + let name = if expect_token_type(parser, TokenTypes::To).is_ok() { + parser.advance()?; + expect_token_type(parser, TokenTypes::Savepoint)?; + parser.advance()?; + expect_token_type(parser, TokenTypes::Identifier)?; + let name = parser.current_token()?.value.to_string(); + parser.advance()?; + Some(name) + } else { + None + }; + expect_token_type(parser, TokenTypes::SemiColon)?; + return Ok(SqlStatement::Rollback(RollbackStatement { + savepoint_name: name, + })); +} + +pub fn build_savepoint(parser: &mut Parser) -> Result { + parser.advance()?; + expect_token_type(parser, TokenTypes::Identifier)?; + let savepoint_name = parser.current_token()?.value.to_string(); + parser.advance()?; + expect_token_type(parser, TokenTypes::SemiColon)?; + return Ok(SqlStatement::Savepoint(SavepointStatement { + savepoint_name: savepoint_name, + })); +} + +pub fn build_release(parser: &mut Parser) -> Result { + parser.advance()?; + expect_token_type(parser, TokenTypes::Savepoint)?; + parser.advance()?; + expect_token_type(parser, TokenTypes::Identifier)?; + let savepoint_name = parser.current_token()?.value.to_string(); + parser.advance()?; + expect_token_type(parser, TokenTypes::SemiColon)?; + return Ok(SqlStatement::Release(ReleaseStatement { + savepoint_name: savepoint_name, + })); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::interpreter::ast::test_utils::token; + + #[test] + fn build_begin_with_all_tokens_is_generated_correctly() { + // BEGIN DEFERRED; BEGIN EXCLUSIVE; BEGIN IMMEDIATE; BEGIN; + let begin_tokens = vec! [ + token(TokenTypes::Begin, "BEGIN"), + token(TokenTypes::Deferred, "DEFERRED"), + token(TokenTypes::SemiColon, ";"), + token(TokenTypes::Begin, "BEGIN"), + token(TokenTypes::Exclusive, "EXCLUSIVE"), + token(TokenTypes::SemiColon, ";"), + token(TokenTypes::Begin, "BEGIN"), + token(TokenTypes::Immediate, "IMMEDIATE"), + token(TokenTypes::SemiColon, ";"), + token(TokenTypes::Begin, "BEGIN"), + token(TokenTypes::SemiColon, ";"), + ]; + let expected = vec![ + Some(Ok(SqlStatement::BeginTransaction(BeginStatement::Deferred))), + Some(Ok(SqlStatement::BeginTransaction(BeginStatement::Exclusive))), + Some(Ok(SqlStatement::BeginTransaction(BeginStatement::Immediate))), + Some(Ok(SqlStatement::BeginTransaction(BeginStatement::Deferred))) + ]; + let mut parser = Parser::new(begin_tokens); + for i in 0..3 { + let result = parser.next_statement(); + assert_eq!(expected[i], result); + let _ = parser.advance_past_semicolon(); + } + } + + #[test] + fn build_commit_with_all_tokens_is_generated_correctly() { + // COMMIT; END; + let commit_tokens = vec![ + token(TokenTypes::Commit, "COMMIT"), + token(TokenTypes::SemiColon, ";"), + token(TokenTypes::End, "END"), + token(TokenTypes::SemiColon, ";"), + ]; + let expected = vec![ + Some(Ok(SqlStatement::Commit)), + Some(Ok(SqlStatement::Commit)), + ]; + let mut parser = Parser::new(commit_tokens); + for i in 0..2 { + let result = parser.next_statement(); + assert_eq!(expected[i], result); + let _ = parser.advance_past_semicolon(); + } + } + + #[test] + fn build_rollback_with_all_tokens_is_generated_correctly() { + // ROLLBACK; ROLLBACK TO savepoint_name; + let rollback_tokens = vec![ + token(TokenTypes::Rollback, "ROLLBACK"), + token(TokenTypes::SemiColon, ";"), + token(TokenTypes::Rollback, "ROLLBACK"), + token(TokenTypes::To, "TO"), + token(TokenTypes::Savepoint, "SAVEPOINT"), + token(TokenTypes::Identifier, "savepoint_name"), + token(TokenTypes::SemiColon, ";"), + ]; + let expected = vec![ + Some(Ok(SqlStatement::Rollback(RollbackStatement { + savepoint_name: None, + }))), + Some(Ok(SqlStatement::Rollback(RollbackStatement { + savepoint_name: Some("savepoint_name".to_string()), + }))), + ]; + let mut parser = Parser::new(rollback_tokens); + for i in 0..2 { + let result = parser.next_statement(); + assert_eq!(expected[i], result); + let _ = parser.advance_past_semicolon(); + } + } + + #[test] + fn build_savepoint_with_all_tokens_is_generated_correctly() { + // SAVEPOINT savepoint_name; + let savepoint_tokens = vec![ + token(TokenTypes::Savepoint, "SAVEPOINT"), + token(TokenTypes::Identifier, "savepoint_name"), + token(TokenTypes::SemiColon, ";"), + ]; + let expected = + Some(Ok(SqlStatement::Savepoint(SavepointStatement { + savepoint_name: "savepoint_name".to_string(), + }))); + let mut parser = Parser::new(savepoint_tokens); + let result = parser.next_statement(); + assert_eq!(expected, result); + } + + #[test] + fn build_release_with_all_tokens_is_generated_correctly() { + // RELEASE SAVEPOINT savepoint_name; + let release_tokens = vec![ + token(TokenTypes::Release, "RELEASE"), + token(TokenTypes::Savepoint, "SAVEPOINT"), + token(TokenTypes::Identifier, "savepoint_name"), + token(TokenTypes::SemiColon, ";"), + ]; + let expected = Some(Ok(SqlStatement::Release(ReleaseStatement { + savepoint_name: "savepoint_name".to_string(), + }))); + let mut parser = Parser::new(release_tokens); + let result = parser.next_statement(); + assert_eq!(expected, result); + } +} \ No newline at end of file