Skip to content
25 changes: 19 additions & 6 deletions src/db/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::db::table::core::{row::Row, table::Table};
use crate::db::table::operations::{
alter_table, create_table, delete, drop_table, insert, select, update,
};
use crate::db::transactions::rollback::rollback_transaction_entry;
use crate::db::transactions::{TransactionEntry, TransactionLog};
use crate::interpreter::ast::SqlStatement;
use std::collections::HashMap;
Expand Down Expand Up @@ -81,10 +82,22 @@ impl Database {
Ok(None)
}
SqlStatement::Rollback(_) => {
self.transaction.commit_transaction()?;
self.tables.iter_mut().for_each(|(_, table)| {
table.rollback_transaction();
});
if let Some(transaction_log) = self.transaction.commit_transaction()?.entries {
// We roll back in reverse order because of dependencies.
for transaction_entry in transaction_log.iter().rev() {
match transaction_entry {
TransactionEntry::Statement(statement) => {
// TODO: Some matching needs to be here for table based operations.
// CURRENTLY SUPPORTED STATEMENTS ARE:
// - ALTER TABLE RENAME COLUMN, ALTER TABLE ADD COLUMN, ALTER TABLE DROP COLUMN
rollback_transaction_entry(self, &statement)?;
}
TransactionEntry::Savepoint(_) => {}
}
}
} else {
return Err("No transaction is currently active".to_string());
}
Ok(None)
}
SqlStatement::Savepoint(_) => {
Expand Down Expand Up @@ -159,13 +172,13 @@ mod tests {
let mut database = default_database();
let table = database.get_table("users");
assert!(table.is_ok());
assert_eq!("users", table.unwrap().name);
assert_eq!("users", table.unwrap().name().unwrap());
let table = database.get_table("not_users");
assert!(table.is_err());
assert_eq!("Table `not_users` does not exist", table.unwrap_err());
let table = database.get_table_mut("users");
assert!(table.is_ok());
assert_eq!("users", table.unwrap().name);
assert_eq!("users", table.unwrap().name().unwrap());
let table = database.get_table_mut("not_users");
assert!(table.is_err());
assert_eq!("Table `not_users` does not exist", table.unwrap_err());
Expand Down
12 changes: 2 additions & 10 deletions src/db/table/core/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ impl ColumnStack {
&mut self,
old_column_name: &String,
new_column_name: &String,
table_name: &String,
is_transaction: bool,
) -> Result<(), String> {
if is_transaction {
Expand All @@ -53,10 +52,7 @@ impl ColumnStack {
match columns {
Some(column) => column.name = new_column_name.clone(),
None => {
return Err(format!(
"Column `{}` does not exist in table `{}`",
old_column_name, table_name
));
return Err("Column does not exist".to_string());
}
}
Ok(())
Expand All @@ -65,7 +61,6 @@ impl ColumnStack {
pub fn drop_column(
&mut self,
column_name: &String,
table_name: &String,
is_transaction: bool,
) -> Result<(), String> {
if is_transaction {
Expand All @@ -74,10 +69,7 @@ impl ColumnStack {
match self.get_index_of_column(column_name) {
Ok(index) => self.peek_mut()?.remove(index),
Err(_) => {
return Err(format!(
"Column `{}` does not exist in table `{}`",
column_name, table_name
));
return Err("Column does not exist".to_string());
}
};
Ok(())
Expand Down
98 changes: 73 additions & 25 deletions src/db/table/core/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ use std::ops::{Index, IndexMut};

#[derive(Debug)]
pub struct Table {
pub name: String,
pub name: NameStack,
pub columns: ColumnStack,
rows: Vec<RowStack>,
}

#[derive(Debug)]
pub struct NameStack {
pub stack: Vec<String>,
}

impl Index<usize> for Table {
type Output = Row;

Expand All @@ -29,12 +34,27 @@ impl IndexMut<usize> for Table {
impl Table {
pub fn new(name: String, columns: Vec<ColumnDefinition>) -> Self {
Self {
name,
name: NameStack { stack: vec![name] },
columns: ColumnStack::new(columns),
rows: vec![],
}
}

pub fn name(&self) -> Result<&String, String> {
self.name
.stack
.last()
.ok_or("Error fetching table name.".to_string())
}

pub fn change_name(&mut self, new_name: String, is_transaction: bool) {
if is_transaction {
self.name.stack.push(new_name);
} else {
self.name.stack = vec![new_name];
}
}

pub fn get(&self, i: usize) -> Option<&Row> {
self.rows.get(i)?.stack.last()
}
Expand Down Expand Up @@ -106,60 +126,88 @@ impl Table {
Ok(())
}

pub fn rollback_transaction(&mut self) {
todo!()
pub fn rollback_columns(&mut self) {
self.columns.stack.pop();
}

pub fn get_column_from_row<'a>(&self, row: &'a Vec<Value>, column: &String) -> &'a Value {
pub fn rollback_all_rows(&mut self) {
for row_stack in self.rows.iter_mut() {
row_stack.stack.pop();
}
}

pub fn rollback_name(&mut self) {
self.name.stack.pop();
}

pub fn get_column_from_row<'a>(
&self,
row: &'a Vec<Value>,
column: &String,
) -> Result<&'a Value, String> {
for (i, value) in row.iter().enumerate() {
if self.get_column_names()[i] == column {
return &value;
if self.get_column_names()?[i] == column {
return Ok(&value);
}
}
return &Value::Null;
return Ok(&Value::Null);
}

pub fn has_column(&self, column: &String) -> bool {
self.get_columns().iter().any(|c| c.name == *column)
pub fn has_column(&self, column: &String) -> Result<bool, String> {
Ok(self.get_columns()?.iter().any(|c| c.name == *column))
}

pub fn width(&self) -> usize {
self.get_columns().len()
pub fn width(&self) -> Result<usize, String> {
Ok(self.get_columns()?.len())
}

pub fn get_index_of_column(&self, column: &String) -> Result<usize, String> {
for (i, c) in self.get_columns().iter().enumerate() {
for (i, c) in self.get_columns()?.iter().enumerate() {
if c.name == *column {
return Ok(i);
}
}
return Err(format!(
"Column {} does not exist in table {}",
column, self.name
column,
self.name()?
));
}

pub fn get_columns(&self) -> Vec<&ColumnDefinition> {
self.columns.stack.last().unwrap().iter().collect()
pub fn get_columns(&self) -> Result<Vec<&ColumnDefinition>, String> {
Ok(self
.columns
.stack
.last()
.ok_or("Column stack is empty".to_string())?
.iter()
.collect())
}

pub fn get_columns_mut(&mut self) -> Vec<&mut ColumnDefinition> {
self.columns.stack.last_mut().unwrap().iter_mut().collect()
pub fn get_columns_mut(&mut self) -> Result<Vec<&mut ColumnDefinition>, String> {
Ok(self
.columns
.stack
.last_mut()
.ok_or("Column stack is empty".to_string())?
.iter_mut()
.collect())
}

pub fn get_column_names(&self) -> Vec<&String> {
self.get_columns()
pub fn get_column_names(&self) -> Result<Vec<&String>, String> {
Ok(self
.get_columns()?
.iter()
.map(|column| &column.name)
.collect()
.collect())
}

pub fn push_column(&mut self, column: ColumnDefinition) {
self.columns.push_column(column, false);
pub fn push_column(&mut self, column: ColumnDefinition, is_transaction: bool) {
self.columns.push_column(column, is_transaction);
}

#[cfg(test)]
pub fn get_columns_clone(&self) -> Vec<ColumnDefinition> {
self.get_columns().iter().map(|c| (*c).clone()).collect()
pub fn get_columns_clone(&self) -> Result<Vec<ColumnDefinition>, String> {
Ok(self.get_columns()?.iter().map(|c| (*c).clone()).collect())
}
}
Loading
Loading