Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/db/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ impl Database {

fn get_table(&self, table_name: &str) -> Result<&Table, String> {
if !self.has_table(table_name) {
return Err(format!("Table {} does not exist", table_name));
return Err(format!("Table not found: {}", table_name));
}
Ok(self.tables.get(table_name).unwrap())
}

fn get_table_mut(&mut self, table_name: &str) -> Result<&mut Table, String> {
if !self.has_table(table_name) {
return Err(format!("Table {} does not exist", table_name));
return Err(format!("Table not found: {}", table_name));
}
Ok(self.tables.get_mut(table_name).unwrap())
}
Expand Down Expand Up @@ -162,12 +162,12 @@ mod tests {
assert_eq!(table.unwrap().name, "users");
let table = database.get_table("not_users");
assert!(table.is_err());
assert_eq!(table.unwrap_err(), "Table not_users does not exist");
assert_eq!(table.unwrap_err(), "Table not found: not_users");
let table = database.get_table_mut("users");
assert!(table.is_ok());
assert_eq!(table.unwrap().name, "users");
let table = database.get_table_mut("not_users");
assert!(table.is_err());
assert_eq!(table.unwrap_err(), "Table not_users does not exist");
assert_eq!(table.unwrap_err(), "Table not found: not_users");
}
}
3 changes: 2 additions & 1 deletion src/interpreter/ast/helpers/select_statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{interpreter::{
ast::{
parser::Parser, SelectStatement, SelectStatementColumns, WhereStackElement,
helpers::{
common::{tokens_to_identifier_list, get_table_name},
common::{tokens_to_identifier_list, get_table_name, expect_token_type},
order_by_clause::get_order_by, where_stack::get_where_clause, limit_clause::get_limit
}
},
Expand All @@ -12,6 +12,7 @@ use crate::{interpreter::{
pub fn get_statement(parser: &mut Parser) -> Result<SelectStatement, String> {
parser.advance()?;
let columns = get_columns(parser)?;
expect_token_type(parser, TokenTypes::From)?;
let table_name = get_table_name(parser)?;
parser.advance()?;
let where_clause: Option<Vec<WhereStackElement>> = get_where_clause(parser)?;
Expand Down
159 changes: 105 additions & 54 deletions src/interpreter/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ mod helpers;
#[cfg(test)]
mod test_utils;

#[derive(Debug, PartialEq)]
pub struct DatabaseSqlStatement {
pub sql_statement: SqlStatement,
pub line_num: usize,
pub statement_text: String,
}

#[derive(Debug, PartialEq)]
pub enum SqlStatement {
CreateTable(CreateTableStatement),
Expand Down Expand Up @@ -235,34 +242,59 @@ impl StatementBuilder for DefaultStatementBuilder {
}
}

pub fn generate(tokens: Vec<Token>) -> Vec<Result<SqlStatement, String>> {
let mut results: Vec<Result<SqlStatement, String>> = vec![];
pub fn generate(tokens: Vec<Token>) -> Vec<Result<DatabaseSqlStatement, String>> {
let mut results: Vec<Result<DatabaseSqlStatement, String>> = vec![];
let mut parser = parser::Parser::new(tokens);
let builder : &dyn StatementBuilder = &DefaultStatementBuilder;
loop {
let line_num = match parser.line_num() {
Ok(line_num) => line_num,
Err(err) => {
results.push(Err(err));
break;
}
};
let next_statement = parser.next_statement(builder);
if let Some(next_statement) = next_statement {
if next_statement.is_err() {
loop {
if let Ok(token) = parser.current_token() {
if token.token_type != TokenTypes::EOF && token.token_type != TokenTypes::SemiColon {
let _ = parser.advance();
match next_statement {
Err(error) => {
results.push(Err(error));
// If we encountered a parsing error, skip until we find a semicolon or EOF
loop {
if let Ok(token) = parser.current_token() {
if token.token_type == TokenTypes::EOF {
break;
}
else if token.token_type == TokenTypes::SemiColon {
let _ = parser.advance_past_semicolon();
break;
}
else {
if parser.advance().is_err() {
return results;
}
}
}
else {
break;
}
}
else {
break;
}
Ok(sql_statement) => {
let parser_advance_result = parser.advance_past_semicolon();
if parser_advance_result.is_err() {
results.push(Err(parser_advance_result.err().unwrap()));
return results;
}
results.push(
Ok(DatabaseSqlStatement {
sql_statement: sql_statement,
line_num: line_num,
statement_text: "".to_string(),
})
);
}
}
let parser_advance_result = parser.advance_past_semicolon();
if parser_advance_result.is_err() {
results.push(Err(parser_advance_result.err().unwrap()));
return results;
}
results.push(next_statement);
} else {
break;
}
Expand Down Expand Up @@ -319,22 +351,30 @@ mod tests {
assert!(result[0].is_ok());
assert!(result[1].is_ok());
let expected = vec![
Ok(SqlStatement::Select(SelectStatementStack {
elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement {
table_name: "users".to_string(),
columns: SelectStatementColumns::All,
where_clause: None,
order_by_clause: None,
limit_clause: None,
})],
})),
Ok(SqlStatement::InsertInto(InsertIntoStatement {
table_name: "users".to_string(),
columns: None,
values: vec![
vec![Value::Integer(1), Value::Text("Alice".to_string())],
],
})),
Ok(DatabaseSqlStatement {
sql_statement: SqlStatement::Select(SelectStatementStack {
elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement {
table_name: "users".to_string(),
columns: SelectStatementColumns::All,
where_clause: None,
order_by_clause: None,
limit_clause: None,
})],
}),
line_num: 1,
statement_text: "".to_string(),
}),
Ok(DatabaseSqlStatement {
sql_statement: SqlStatement::InsertInto(InsertIntoStatement {
table_name: "users".to_string(),
columns: None,
values: vec![
vec![Value::Integer(1), Value::Text("Alice".to_string())],
],
}),
line_num: 1,
statement_text: "".to_string(),
}),
];
assert_eq!(expected, result);
}
Expand Down Expand Up @@ -363,14 +403,17 @@ mod tests {
assert!(result[1].is_ok());
let expected = vec![
Err("Error at line 1, column 0: Unexpected value: ;".to_string()),
Ok(SqlStatement::InsertInto(InsertIntoStatement {

table_name: "users".to_string(),
columns: None,
values: vec![
vec![Value::Integer(1), Value::Text("Alice".to_string())],
],
})),
Ok(DatabaseSqlStatement {
sql_statement: SqlStatement::InsertInto(InsertIntoStatement {
table_name: "users".to_string(),
columns: None,
values: vec![
vec![Value::Integer(1), Value::Text("Alice".to_string())],
],
}),
line_num: 1,
statement_text: "".to_string(),
}),
];
assert_eq!(expected, result);
}
Expand Down Expand Up @@ -399,22 +442,30 @@ mod tests {
assert!(result[0].is_ok());
assert!(result[1].is_ok());
let expected = vec![
Ok(SqlStatement::Select(SelectStatementStack {
elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement {
Ok(DatabaseSqlStatement {
sql_statement: SqlStatement::Select(SelectStatementStack {
elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement {
table_name: "users".to_string(),
columns: SelectStatementColumns::All,
where_clause: None,
order_by_clause: None,
limit_clause: None,
})],
}),
line_num: 1,
statement_text: "".to_string(),
}),
Ok(DatabaseSqlStatement {
sql_statement: SqlStatement::InsertInto(InsertIntoStatement {
table_name: "users".to_string(),
columns: SelectStatementColumns::All,
where_clause: None,
order_by_clause: None,
limit_clause: None,
})],
})),
Ok(SqlStatement::InsertInto(InsertIntoStatement {
table_name: "users".to_string(),
columns: None,
values: vec![
vec![Value::Integer(1), Value::Text("Alice".to_string())],
],
})),
columns: None,
values: vec![
vec![Value::Integer(1), Value::Text("Alice".to_string())],
],
}),
line_num: 1,
statement_text: "".to_string(),
}),
];
assert_eq!(expected, result);
}
Expand Down
4 changes: 4 additions & 0 deletions src/interpreter/ast/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ impl<'a> Parser<'a> {
current: 0,
};
}

pub fn line_num(&self) -> Result<usize, String> {
return Ok(self.current_token()?.line_num);
}

pub fn current_token(&self) -> Result<&Token<'a>, String> {
if self.current >= self.tokens.len() {
Expand Down
24 changes: 13 additions & 11 deletions src/interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,23 @@ pub fn run_sql(database: &mut db::database::Database, sql: &str) -> Vec<Result<O
// println!("{:?}", sql_statement);
match sql_statement {
Ok(statement) => {
let result = database.execute(statement);
if let Ok(values) = result {
if let Some(rows) = values {
sql_results.push(Ok(Some(rows)));
let result = database.execute(statement.sql_statement);
match result {
Ok(values) => {
if let Some(rows) = values {
sql_results.push(Ok(Some(rows)));
}
else {
sql_results.push(Ok(None));
}
}
else {
sql_results.push(Ok(None));
Err(error) => {
sql_results.push(Err(format!("Execution Error with statement starting on line {} \n Error: {}", statement.line_num, error)));
}
}
else {
sql_results.push(Err(result.unwrap_err()));
}
},
Err(error) => {
sql_results.push(Err(error));
Err(parser_error) => {
sql_results.push(Err(format!("Parsing Error: {}", parser_error)));
},
}
}
Expand Down
35 changes: 35 additions & 0 deletions tests/crud_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,39 @@ fn test_complex_statements_crud() {
assert!(result.pop().unwrap().unwrap().is_none());
assert_eq!(expected_first, result.pop().unwrap().unwrap().unwrap());
assert!(result.into_iter().all(|result| result.is_ok() && result.unwrap().is_none()));
}

#[test]
fn test_parsing_errors() {
let mut database = Database::new();
let sql = "
CREATE TABLE abc (
id hello,
name TEXT,
age INTEGER,
money REAL
);
SELECT * FROM users wherea;
SELECT * users;
";
let result = run_sql(&mut database, sql);
assert!(result.iter().all(|result| result.is_err()));
let expected = vec![
Err("Parsing Error: Error at line 3, column 11: Unexpected value: hello".to_string()),
Err("Parsing Error: Error at line 8, column 24: Unexpected value: wherea".to_string()),
Err("Parsing Error: Error at line 9, column 13: Unexpected value: users".to_string()),
];
assert_eq!(expected, result);
}

#[test]
fn test_execution_errors() {
let mut database = Database::new();
let sql = "
SELECT * FROM users WHERE id = 'hello';
";
let result = run_sql(&mut database, sql);
assert!(result.iter().all(|result| result.is_err()));
let expected = vec![Err("Execution Error with statement starting on line 2 \n Error: Table not found: users".to_string())];
assert_eq!(expected, result);
}
Loading