diff --git a/src/db/database.rs b/src/db/database.rs index da48369..430af8e 100644 --- a/src/db/database.rs +++ b/src/db/database.rs @@ -91,14 +91,14 @@ impl Database { fn get_table(&self, table_name: &str) -> Result<&Table, String> { if !self.has_table(table_name) { - return Err(format!("Table {} does not exist", table_name)); + return Err(format!("Table not found: {}", table_name)); } Ok(self.tables.get(table_name).unwrap()) } fn get_table_mut(&mut self, table_name: &str) -> Result<&mut Table, String> { if !self.has_table(table_name) { - return Err(format!("Table {} does not exist", table_name)); + return Err(format!("Table not found: {}", table_name)); } Ok(self.tables.get_mut(table_name).unwrap()) } @@ -162,12 +162,12 @@ mod tests { assert_eq!(table.unwrap().name, "users"); let table = database.get_table("not_users"); assert!(table.is_err()); - assert_eq!(table.unwrap_err(), "Table not_users does not exist"); + assert_eq!(table.unwrap_err(), "Table not found: not_users"); let table = database.get_table_mut("users"); assert!(table.is_ok()); assert_eq!(table.unwrap().name, "users"); let table = database.get_table_mut("not_users"); assert!(table.is_err()); - assert_eq!(table.unwrap_err(), "Table not_users does not exist"); + assert_eq!(table.unwrap_err(), "Table not found: not_users"); } } \ No newline at end of file diff --git a/src/interpreter/ast/helpers/select_statement.rs b/src/interpreter/ast/helpers/select_statement.rs index 7c708cf..8684483 100644 --- a/src/interpreter/ast/helpers/select_statement.rs +++ b/src/interpreter/ast/helpers/select_statement.rs @@ -2,7 +2,7 @@ use crate::{interpreter::{ ast::{ parser::Parser, SelectStatement, SelectStatementColumns, WhereStackElement, helpers::{ - common::{tokens_to_identifier_list, get_table_name}, + common::{tokens_to_identifier_list, get_table_name, expect_token_type}, order_by_clause::get_order_by, where_stack::get_where_clause, limit_clause::get_limit } }, @@ -12,6 +12,7 @@ use crate::{interpreter::{ pub fn get_statement(parser: &mut Parser) -> Result { parser.advance()?; let columns = get_columns(parser)?; + expect_token_type(parser, TokenTypes::From)?; let table_name = get_table_name(parser)?; parser.advance()?; let where_clause: Option> = get_where_clause(parser)?; diff --git a/src/interpreter/ast/mod.rs b/src/interpreter/ast/mod.rs index d401a9e..d916488 100644 --- a/src/interpreter/ast/mod.rs +++ b/src/interpreter/ast/mod.rs @@ -11,6 +11,13 @@ mod helpers; #[cfg(test)] mod test_utils; +#[derive(Debug, PartialEq)] +pub struct DatabaseSqlStatement { + pub sql_statement: SqlStatement, + pub line_num: usize, + pub statement_text: String, +} + #[derive(Debug, PartialEq)] pub enum SqlStatement { CreateTable(CreateTableStatement), @@ -235,34 +242,59 @@ impl StatementBuilder for DefaultStatementBuilder { } } -pub fn generate(tokens: Vec) -> Vec> { - let mut results: Vec> = vec![]; +pub fn generate(tokens: Vec) -> Vec> { + let mut results: Vec> = vec![]; let mut parser = parser::Parser::new(tokens); let builder : &dyn StatementBuilder = &DefaultStatementBuilder; loop { + let line_num = match parser.line_num() { + Ok(line_num) => line_num, + Err(err) => { + results.push(Err(err)); + break; + } + }; let next_statement = parser.next_statement(builder); if let Some(next_statement) = next_statement { - if next_statement.is_err() { - loop { - if let Ok(token) = parser.current_token() { - if token.token_type != TokenTypes::EOF && token.token_type != TokenTypes::SemiColon { - let _ = parser.advance(); + match next_statement { + Err(error) => { + results.push(Err(error)); + // If we encountered a parsing error, skip until we find a semicolon or EOF + loop { + if let Ok(token) = parser.current_token() { + if token.token_type == TokenTypes::EOF { + break; + } + else if token.token_type == TokenTypes::SemiColon { + let _ = parser.advance_past_semicolon(); + break; + } + else { + if parser.advance().is_err() { + return results; + } + } } else { break; } } - else { - break; + } + Ok(sql_statement) => { + let parser_advance_result = parser.advance_past_semicolon(); + if parser_advance_result.is_err() { + results.push(Err(parser_advance_result.err().unwrap())); + return results; } + results.push( + Ok(DatabaseSqlStatement { + sql_statement: sql_statement, + line_num: line_num, + statement_text: "".to_string(), + }) + ); } } - let parser_advance_result = parser.advance_past_semicolon(); - if parser_advance_result.is_err() { - results.push(Err(parser_advance_result.err().unwrap())); - return results; - } - results.push(next_statement); } else { break; } @@ -319,22 +351,30 @@ mod tests { assert!(result[0].is_ok()); assert!(result[1].is_ok()); let expected = vec![ - Ok(SqlStatement::Select(SelectStatementStack { - elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement { - table_name: "users".to_string(), - columns: SelectStatementColumns::All, - where_clause: None, - order_by_clause: None, - limit_clause: None, - })], - })), - Ok(SqlStatement::InsertInto(InsertIntoStatement { - table_name: "users".to_string(), - columns: None, - values: vec![ - vec![Value::Integer(1), Value::Text("Alice".to_string())], - ], - })), + Ok(DatabaseSqlStatement { + sql_statement: SqlStatement::Select(SelectStatementStack { + elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement { + table_name: "users".to_string(), + columns: SelectStatementColumns::All, + where_clause: None, + order_by_clause: None, + limit_clause: None, + })], + }), + line_num: 1, + statement_text: "".to_string(), + }), + Ok(DatabaseSqlStatement { + sql_statement: SqlStatement::InsertInto(InsertIntoStatement { + table_name: "users".to_string(), + columns: None, + values: vec![ + vec![Value::Integer(1), Value::Text("Alice".to_string())], + ], + }), + line_num: 1, + statement_text: "".to_string(), + }), ]; assert_eq!(expected, result); } @@ -363,14 +403,17 @@ mod tests { assert!(result[1].is_ok()); let expected = vec![ Err("Error at line 1, column 0: Unexpected value: ;".to_string()), - Ok(SqlStatement::InsertInto(InsertIntoStatement { - - table_name: "users".to_string(), - columns: None, - values: vec![ - vec![Value::Integer(1), Value::Text("Alice".to_string())], - ], - })), + Ok(DatabaseSqlStatement { + sql_statement: SqlStatement::InsertInto(InsertIntoStatement { + table_name: "users".to_string(), + columns: None, + values: vec![ + vec![Value::Integer(1), Value::Text("Alice".to_string())], + ], + }), + line_num: 1, + statement_text: "".to_string(), + }), ]; assert_eq!(expected, result); } @@ -399,22 +442,30 @@ mod tests { assert!(result[0].is_ok()); assert!(result[1].is_ok()); let expected = vec![ - Ok(SqlStatement::Select(SelectStatementStack { - elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement { + Ok(DatabaseSqlStatement { + sql_statement: SqlStatement::Select(SelectStatementStack { + elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement { + table_name: "users".to_string(), + columns: SelectStatementColumns::All, + where_clause: None, + order_by_clause: None, + limit_clause: None, + })], + }), + line_num: 1, + statement_text: "".to_string(), + }), + Ok(DatabaseSqlStatement { + sql_statement: SqlStatement::InsertInto(InsertIntoStatement { table_name: "users".to_string(), - columns: SelectStatementColumns::All, - where_clause: None, - order_by_clause: None, - limit_clause: None, - })], - })), - Ok(SqlStatement::InsertInto(InsertIntoStatement { - table_name: "users".to_string(), - columns: None, - values: vec![ - vec![Value::Integer(1), Value::Text("Alice".to_string())], - ], - })), + columns: None, + values: vec![ + vec![Value::Integer(1), Value::Text("Alice".to_string())], + ], + }), + line_num: 1, + statement_text: "".to_string(), + }), ]; assert_eq!(expected, result); } diff --git a/src/interpreter/ast/parser.rs b/src/interpreter/ast/parser.rs index 3fc0be5..d88fabd 100644 --- a/src/interpreter/ast/parser.rs +++ b/src/interpreter/ast/parser.rs @@ -15,6 +15,10 @@ impl<'a> Parser<'a> { current: 0, }; } + + pub fn line_num(&self) -> Result { + return Ok(self.current_token()?.line_num); + } pub fn current_token(&self) -> Result<&Token<'a>, String> { if self.current >= self.tokens.len() { diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index 012de04..ac9c87f 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -13,21 +13,23 @@ pub fn run_sql(database: &mut db::database::Database, sql: &str) -> Vec { - let result = database.execute(statement); - if let Ok(values) = result { - if let Some(rows) = values { - sql_results.push(Ok(Some(rows))); + let result = database.execute(statement.sql_statement); + match result { + Ok(values) => { + if let Some(rows) = values { + sql_results.push(Ok(Some(rows))); + } + else { + sql_results.push(Ok(None)); + } } - else { - sql_results.push(Ok(None)); + Err(error) => { + sql_results.push(Err(format!("Execution Error with statement starting on line {} \n Error: {}", statement.line_num, error))); } } - else { - sql_results.push(Err(result.unwrap_err())); - } }, - Err(error) => { - sql_results.push(Err(error)); + Err(parser_error) => { + sql_results.push(Err(format!("Parsing Error: {}", parser_error))); }, } } diff --git a/tests/crud_test.rs b/tests/crud_test.rs index 948294e..f8a75a3 100644 --- a/tests/crud_test.rs +++ b/tests/crud_test.rs @@ -64,4 +64,39 @@ fn test_complex_statements_crud() { assert!(result.pop().unwrap().unwrap().is_none()); assert_eq!(expected_first, result.pop().unwrap().unwrap().unwrap()); assert!(result.into_iter().all(|result| result.is_ok() && result.unwrap().is_none())); +} + +#[test] +fn test_parsing_errors() { + let mut database = Database::new(); + let sql = " + CREATE TABLE abc ( + id hello, + name TEXT, + age INTEGER, + money REAL + ); + SELECT * FROM users wherea; + SELECT * users; + "; + let result = run_sql(&mut database, sql); + assert!(result.iter().all(|result| result.is_err())); + let expected = vec![ + Err("Parsing Error: Error at line 3, column 11: Unexpected value: hello".to_string()), + Err("Parsing Error: Error at line 8, column 24: Unexpected value: wherea".to_string()), + Err("Parsing Error: Error at line 9, column 13: Unexpected value: users".to_string()), + ]; + assert_eq!(expected, result); +} + +#[test] +fn test_execution_errors() { + let mut database = Database::new(); + let sql = " + SELECT * FROM users WHERE id = 'hello'; + "; + let result = run_sql(&mut database, sql); + assert!(result.iter().all(|result| result.is_err())); + let expected = vec![Err("Execution Error with statement starting on line 2 \n Error: Table not found: users".to_string())]; + assert_eq!(expected, result); } \ No newline at end of file