Skip to content
18 changes: 14 additions & 4 deletions src/db/database.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::db::table::{Table, Value};
use crate::interpreter::ast::{SqlStatement, CreateTableStatement, InsertIntoStatement, SelectStatement, DeleteStatement, UpdateStatement};
use crate::interpreter::ast::{SqlStatement, CreateTableStatement, InsertIntoStatement, SelectStatement, DeleteStatement, UpdateStatement, SelectStatementStackElement};
use crate::db::table::select;
use crate::db::table::insert;
use crate::db::table::delete;
Expand Down Expand Up @@ -27,9 +27,19 @@ impl Database {
self.insert_into_table(statement)?;
Ok(None)
},
SqlStatement::Select(statement) => {
let rows = self.select_from_table(statement)?;
Ok(Some(rows))
SqlStatement::Select(mut statement) => {
let select_statement = statement.elements.pop();
if let Some(select_statement) = select_statement {
match select_statement {
SelectStatementStackElement::SelectStatement(select_statement) => {
let rows = self.select_from_table(select_statement)?;
Ok(Some(rows))
}
_ => Err(format!("Expected select statement, got {:?}", select_statement)),
}
} else {
Ok(None)
}
},
SqlStatement::UpdateStatement(statement) => {
self.update_table(statement)?;
Expand Down
3 changes: 2 additions & 1 deletion src/interpreter/ast/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ pub mod where_stack;
pub mod where_condition;
pub mod order_by_clause;
pub mod limit_clause;
pub mod common;
pub mod common;
pub mod select_statement;
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use crate::{interpreter::{
ast::{
parser::Parser, SelectStatement, SelectStatementColumns, SqlStatement, WhereStackElement,
parser::Parser, SelectStatement, SelectStatementColumns, WhereStackElement,
helpers::{
common::{expect_token_type, tokens_to_identifier_list, get_table_name},
common::{tokens_to_identifier_list, get_table_name},
order_by_clause::get_order_by, where_stack::get_where_clause, limit_clause::get_limit
}
},
tokenizer::token::TokenTypes
}};

pub fn build(parser: &mut Parser) -> Result<SqlStatement, String> {
pub fn get_statement(parser: &mut Parser) -> Result<SelectStatement, String> {
parser.advance()?;
let columns = get_columns(parser)?;
let table_name = get_table_name(parser)?;
Expand All @@ -18,15 +18,13 @@ pub fn build(parser: &mut Parser) -> Result<SqlStatement, String> {
let order_by_clause = get_order_by(parser)?;
let limit_clause = get_limit(parser)?;

// Ensure SemiColon
expect_token_type(parser, TokenTypes::SemiColon)?;
return Ok(SqlStatement::Select(SelectStatement {
table_name: table_name,
columns: columns,
where_clause: where_clause,
order_by_clause: order_by_clause,
limit_clause: limit_clause,
}));
return Ok(SelectStatement {
table_name: table_name,
columns: columns,
where_clause: where_clause,
order_by_clause: order_by_clause,
limit_clause: limit_clause,
});
}

fn get_columns(parser: &mut Parser) -> Result<SelectStatementColumns, String> {
Expand Down Expand Up @@ -65,16 +63,16 @@ mod tests {
token(TokenTypes::SemiColon, ";"),
];
let mut parser = Parser::new(tokens);
let result = build(&mut parser);
let result = get_statement(&mut parser);
assert!(result.is_ok());
let statement = result.unwrap();
assert_eq!(statement, SqlStatement::Select(SelectStatement {
assert_eq!(statement, SelectStatement {
table_name: "users".to_string(),
columns: SelectStatementColumns::All,
where_clause: None,
order_by_clause: None,
limit_clause: None,
}));
});
}

#[test]
Expand All @@ -88,18 +86,18 @@ mod tests {
token(TokenTypes::SemiColon, ";"),
];
let mut parser = Parser::new(tokens);
let result = build(&mut parser);
let result = get_statement(&mut parser);
assert!(result.is_ok());
let statement = result.unwrap();
assert_eq!(statement, SqlStatement::Select(SelectStatement {
assert_eq!(statement, SelectStatement {
table_name: "guests".to_string(),
columns: SelectStatementColumns::Specific(vec![
"id".to_string(),
]),
where_clause: None,
order_by_clause: None,
limit_clause: None,
}));
});
}

#[test]
Expand All @@ -115,10 +113,10 @@ mod tests {
token(TokenTypes::SemiColon, ";"),
];
let mut parser = Parser::new(tokens);
let result = build(&mut parser);
let result = get_statement(&mut parser);
assert!(result.is_ok());
let statement = result.unwrap();
assert_eq!(statement, SqlStatement::Select(SelectStatement {
assert_eq!(statement, SelectStatement {
table_name: "users".to_string(),
columns: SelectStatementColumns::Specific(vec![
"id".to_string(),
Expand All @@ -127,7 +125,7 @@ mod tests {
where_clause: None,
order_by_clause: None,
limit_clause: None,
}));
});
}

#[test]
Expand Down Expand Up @@ -158,10 +156,10 @@ mod tests {
token(TokenTypes::SemiColon, ";"),
];
let mut parser = Parser::new(tokens);
let result = build(&mut parser);
let result = get_statement(&mut parser);
assert!(result.is_ok());
let statement = result.unwrap();
assert_eq!(statement, SqlStatement::Select(SelectStatement {
let expected = SelectStatement {
table_name: "guests".to_string(),
columns: SelectStatementColumns::Specific(vec![
"id".to_string(),
Expand Down Expand Up @@ -191,6 +189,7 @@ mod tests {
limit: Value::Integer(10),
offset: Some(Value::Integer(5)),
}),
}));
};
assert_eq!(expected, statement);
}
}
31 changes: 23 additions & 8 deletions src/interpreter/ast/helpers/where_stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub fn get_where_clause(parser: &mut Parser) -> Result<Option<Vec<WhereStackElem
let mut expected_next_token = WhereClauseExpectedNextToken::ConditionLeftParenNot;

loop {
let where_condition = get_where_condition(parser)?;
let where_condition = get_where_condition(parser, &operator_stack)?;
match where_condition {
Some(where_stack_element) => {
match where_stack_element {
Expand Down Expand Up @@ -131,7 +131,7 @@ pub fn get_where_clause(parser: &mut Parser) -> Result<Option<Vec<WhereStackElem
Ok(Some(where_stack))
}

fn get_where_condition(parser: &mut Parser) -> Result<Option<WhereStackElement>, String> {
fn get_where_condition(parser: &mut Parser, operator_stack: &Vec<WhereStackOperators>) -> Result<Option<WhereStackElement>, String> {
let token = parser.current_token()?;
match token.token_type {
// Logical operators and parentheses
Expand All @@ -152,8 +152,18 @@ fn get_where_condition(parser: &mut Parser) -> Result<Option<WhereStackElement>,
return Ok(Some(WhereStackElement::Parentheses(Parentheses::Left)))
},
TokenTypes::RightParen => {
parser.advance()?;
return Ok(Some(WhereStackElement::Parentheses(Parentheses::Right)))
// TODO improve this check.
let has_matching_left_paren = operator_stack.iter().any(|op| {
matches!(op, WhereStackOperators::Parentheses(Parentheses::Left))
});

if has_matching_left_paren {
parser.advance()?;
return Ok(Some(WhereStackElement::Parentheses(Parentheses::Right)))
} else {
// We may have a mismatched parenthesis from the UNION STATEMENTs causing this.
return Ok(None);
}
},
// Conditions
TokenTypes::Identifier | TokenTypes::IntLiteral | TokenTypes::RealLiteral | TokenTypes::String | TokenTypes::Blob | TokenTypes::Null => {
Expand Down Expand Up @@ -398,7 +408,8 @@ mod tests {
}

#[test]
fn returns_error_for_extra_closing_parenthesis() {
fn does_not_return_error_for_extra_closing_parenthesis() {
// This extra closing parenthesis could be from the UNION STATEMENTs and would error on that level.
// WHERE (id = 1 OR name = "John")); (extra closing parenthesis)
let tokens = vec![
token(TokenTypes::Where, "WHERE"),
Expand All @@ -416,9 +427,13 @@ mod tests {
];
let mut parser = Parser::new(tokens);
let result = get_where_clause(&mut parser);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "Mismatched parentheses found.");
assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::SemiColon);
assert!(result.is_ok());
assert_eq!(result.unwrap(), Some(vec![
simple_condition("id", Operator::Equals, Value::Integer(1)),
simple_condition("name", Operator::Equals, Value::Text("John".to_string())),
WhereStackElement::LogicalOperator(LogicalOperator::Or),
]));
assert_eq!(parser.current_token().unwrap().token_type, TokenTypes::RightParen);
}

#[test]
Expand Down
69 changes: 54 additions & 15 deletions src/interpreter/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::db::table::{ColumnDefinition, Value};
mod create_statement;
mod insert_statement;
mod parser;
mod select_statement;
mod select_statement_stack;
mod update_statement;
mod delete_statement;
mod helpers;
Expand All @@ -15,7 +15,7 @@ mod test_utils;
pub enum SqlStatement {
CreateTable(CreateTableStatement),
InsertInto(InsertIntoStatement),
Select(SelectStatement),
Select(SelectStatementStack),
UpdateStatement(UpdateStatement),
DeleteStatement(DeleteStatement),
}
Expand All @@ -33,6 +33,41 @@ pub struct InsertIntoStatement {
pub values: Vec<Vec<Value>>,
}

#[derive(Debug, PartialEq)]
pub struct SelectStatementStack {
pub elements: Vec<SelectStatementStackElement>,
}

#[derive(Debug, PartialEq)]
pub enum SelectStatementStackElement {
SelectStatement(SelectStatement),
SetOperator(SetOperator),
}

#[derive(Debug, PartialEq)]
pub enum SelectStackOperators {
SetOperator(SetOperator),
Parentheses(Parentheses),
}

#[derive(Debug, PartialEq)]
pub enum SetOperator {
Union,
UnionAll,
Intersect,
Except,
}

impl SetOperator {
pub fn is_greater_precedence(&self, other: &SetOperator) -> bool {
match (self, other) {
(SetOperator::Intersect, SetOperator::Intersect) => false,
(SetOperator::Intersect, _) => true,
(_, _) => false,
}
}
}

#[derive(Debug, PartialEq)]
pub struct SelectStatement {
pub table_name: String,
Expand Down Expand Up @@ -188,7 +223,7 @@ impl StatementBuilder for DefaultStatementBuilder {
}

fn build_select(&self, parser: &mut parser::Parser) -> Result<SqlStatement, String> {
select_statement::build(parser)
select_statement_stack::build(parser)
}

fn build_update(&self, parser: &mut parser::Parser) -> Result<SqlStatement, String> {
Expand Down Expand Up @@ -284,12 +319,14 @@ mod tests {
assert!(result[0].is_ok());
assert!(result[1].is_ok());
let expected = vec![
Ok(SqlStatement::Select(SelectStatement {
table_name: "users".to_string(),
columns: SelectStatementColumns::All,
where_clause: None,
order_by_clause: None,
limit_clause: None,
Ok(SqlStatement::Select(SelectStatementStack {
elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement {
table_name: "users".to_string(),
columns: SelectStatementColumns::All,
where_clause: None,
order_by_clause: None,
limit_clause: None,
})],
})),
Ok(SqlStatement::InsertInto(InsertIntoStatement {
table_name: "users".to_string(),
Expand Down Expand Up @@ -362,12 +399,14 @@ mod tests {
assert!(result[0].is_ok());
assert!(result[1].is_ok());
let expected = vec![
Ok(SqlStatement::Select(SelectStatement {
table_name: "users".to_string(),
columns: SelectStatementColumns::All,
where_clause: None,
order_by_clause: None,
limit_clause: None,
Ok(SqlStatement::Select(SelectStatementStack {
elements: vec![SelectStatementStackElement::SelectStatement(SelectStatement {
table_name: "users".to_string(),
columns: SelectStatementColumns::All,
where_clause: None,
order_by_clause: None,
limit_clause: None,
})],
})),
Ok(SqlStatement::InsertInto(InsertIntoStatement {
table_name: "users".to_string(),
Expand Down
Loading
Loading