diff --git a/src/binder/create.rs b/src/binder/create.rs index 8b137891..902c9cf3 100644 --- a/src/binder/create.rs +++ b/src/binder/create.rs @@ -1 +1,55 @@ +use super::Binder; +use crate::binder::{lower_case_name, split_name}; +use crate::catalog::{Column, ColumnDesc}; +use crate::planner::logical_create_table_plan::LogicalCreateTablePlan; +use crate::planner::LogicalPlan; +use crate::types::ColumnId; +use anyhow::Result; +use sqlparser::ast::{ColumnDef, ObjectName}; +use std::collections::HashSet; +impl Binder { + pub(super) fn bind_create_table( + &mut self, + name: ObjectName, + columns: &[ColumnDef], + ) -> Result { + let name = lower_case_name(&name); + + let (_, table_name) = split_name(&name)?; + + let table = self + .context + .catalog + .get_table_by_name(table_name) + .ok_or_else(|| { + anyhow::Error::msg(format!("table {} not found", table_name.to_string())) + })?; + + // check duplicated column names + let mut set = HashSet::new(); + for col in columns.iter() { + if !set.insert(col.name.value.clone()) { + return Err(anyhow::Error::msg(format!( + "bind duplicated column {}", + col.name.value.clone() + ))); + } + } + + let mut columns: Vec = columns + .iter() + .enumerate() + .map(|(_, col)| Column::from(col)) + .collect(); + + let plan = LogicalCreateTablePlan { + table_name: table_name.to_string(), + columns: columns + .into_iter() + .map(|col| (col.name.to_string(), col.desc.clone())) + .collect(), + }; + Ok(plan) + } +} diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 38e66420..aecec08a 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -5,15 +5,12 @@ mod select; use std::collections::HashMap; -use crate::{ - catalog::CatalogRef, - expression::ScalarExpression, - planner::LogicalPlan, -}; +use crate::{catalog::CatalogRef, expression::ScalarExpression, planner::LogicalPlan}; -use anyhow::Result; -use sqlparser::ast::Statement; +use crate::catalog::DEFAULT_SCHEMA_NAME; use crate::types::TableId; +use anyhow::Result; +use sqlparser::ast::{Ident, ObjectName, Statement}; pub struct BinderContext { catalog: CatalogRef, @@ -66,8 +63,31 @@ impl Binder { let plan = self.bind_query(query)?; LogicalPlan::Select(plan) } + Statement::CreateTable { name, columns, .. } => { + let plan = self.bind_create_table(name.to_owned(), &columns)?; + LogicalPlan::CreateTable(plan) + } _ => unimplemented!(), }; Ok(plan) } } + +/// Convert an object name into lower case +fn lower_case_name(name: &ObjectName) -> ObjectName { + ObjectName( + name.0 + .iter() + .map(|ident| Ident::new(ident.value.to_lowercase())) + .collect(), + ) +} + +/// Split an object name into `(schema name, table name)`. +fn split_name(name: &ObjectName) -> Result<(&str, &str)> { + Ok(match name.0.as_slice() { + [table] => (DEFAULT_SCHEMA_NAME, &table.value), + [schema, table] => (&schema.value, &table.value), + _ => return Err(anyhow::anyhow!("Invalid table name: {:?}", name)), + }) +} diff --git a/src/binder/select.rs b/src/binder/select.rs index efec0f6d..89e97c5c 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -16,13 +16,13 @@ use crate::{ use super::Binder; +use crate::catalog::{DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME}; use anyhow::Result; use itertools::Itertools; use sqlparser::ast::{ Expr, Ident, Join, JoinConstraint, JoinOperator, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins, }; -use crate::catalog::{DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME}; impl Binder { pub(super) fn bind_query(&mut self, query: &Query) -> Result { @@ -128,7 +128,8 @@ impl Binder { .map(|ident| Ident::new(ident.value.to_lowercase())) .collect_vec(); - let (_database, _schema, mut table): (&str, &str, &str) = match obj_name.as_slice() { + let (_database, _schema, mut table): (&str, &str, &str) = match obj_name.as_slice() + { [table] => (DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME, &table.value), [schema, table] => (DEFAULT_DATABASE_NAME, &schema.value, &table.value), [database, schema, table] => (&database.value, &schema.value, &table.value), diff --git a/src/catalog/column.rs b/src/catalog/column.rs index c94f0a51..7e6aab37 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -1,4 +1,5 @@ use crate::types::{ColumnId, DataType, IdGenerator}; +use sqlparser::ast::{ColumnDef, ColumnOption}; #[derive(Clone)] pub struct Column { @@ -8,10 +9,7 @@ pub struct Column { } impl Column { - pub(crate) fn new( - column_name: String, - column_desc: ColumnDesc, - ) -> Column { + pub(crate) fn new(column_name: String, column_desc: ColumnDesc) -> Column { Column { id: IdGenerator::build(), name: column_name, @@ -63,6 +61,28 @@ impl ColumnDesc { } } +impl From<&ColumnDef> for Column { + fn from(cdef: &ColumnDef) -> Self { + let mut is_nullable = true; + let mut is_primary_ = false; + for opt in &cdef.options { + match opt.option { + ColumnOption::Null => is_nullable = true, + ColumnOption::NotNull => is_nullable = false, + ColumnOption::Unique { is_primary } => is_primary_ = is_primary, + _ => todo!("column options"), + } + } + Column::new( + cdef.name.value.clone(), + ColumnDesc::new( + DataType::new(cdef.data_type.clone(), is_nullable), + is_primary_, + ), + ) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/catalog/root.rs b/src/catalog/root.rs index f139df64..b0de3747 100644 --- a/src/catalog/root.rs +++ b/src/catalog/root.rs @@ -10,7 +10,10 @@ pub struct Root { impl Root { #[allow(dead_code)] pub(crate) fn new() -> Self { - Root { table_idxs: Default::default(), tables: Default::default() } + Root { + table_idxs: Default::default(), + tables: Default::default(), + } } pub(crate) fn get_table_id_by_name(&self, name: &str) -> Option { @@ -21,7 +24,16 @@ impl Root { self.tables.get(&table_id) } - pub(crate) fn add_table(&mut self, table_name: String, columns: Vec) -> Result { + pub(crate) fn get_table_by_name(&self, name: &str) -> Option<&Table> { + let id = self.table_idxs.get(name)?; + self.tables.get(id) + } + + pub(crate) fn add_table( + &mut self, + table_name: String, + columns: Vec, + ) -> Result { if self.table_idxs.contains_key(&table_name) { return Err(CatalogError::Duplicated("column", table_name)); } diff --git a/src/catalog/table.rs b/src/catalog/table.rs index 38987a2e..d3f0b829 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -1,7 +1,7 @@ -use std::collections::{BTreeMap, HashMap}; -use itertools::Itertools; use crate::catalog::{CatalogError, Column}; use crate::types::{ColumnId, IdGenerator, TableId}; +use itertools::Itertools; +use std::collections::{BTreeMap, HashMap}; pub struct Table { pub id: TableId, @@ -25,7 +25,8 @@ impl Table { } pub(crate) fn get_all_columns(&self) -> Vec<(ColumnId, &Column)> { - self.columns.iter() + self.columns + .iter() .map(|(col_id, col)| (*col_id, col)) .collect_vec() } @@ -33,10 +34,7 @@ impl Table { /// Add a column to the table catalog. pub(crate) fn add_column(&mut self, col_catalog: Column) -> Result { if self.column_idxs.contains_key(&col_catalog.name) { - return Err(CatalogError::Duplicated( - "column", - col_catalog.name.into(), - )); + return Err(CatalogError::Duplicated("column", col_catalog.name.into())); } let col_id = col_catalog.id; @@ -74,19 +72,10 @@ mod tests { // | 1 | true | // | 2 | false | fn test_table_catalog() { - let col0 = Column::new( - "a".into(), - DataTypeKind::Int(None).not_null().to_column(), - ); - let col1 = Column::new( - "b".into(), - DataTypeKind::Boolean.not_null().to_column() - ); + let col0 = Column::new("a".into(), DataTypeKind::Int(None).not_null().to_column()); + let col1 = Column::new("b".into(), DataTypeKind::Boolean.not_null().to_column()); let col_catalogs = vec![col0, col1]; - let table_catalog = Table::new( - "test".to_string(), - col_catalogs - ).unwrap(); + let table_catalog = Table::new("test".to_string(), col_catalogs).unwrap(); assert_eq!(table_catalog.contains_column("a"), true); assert_eq!(table_catalog.contains_column("b"), true); diff --git a/src/planner/logical_create_table_plan.rs b/src/planner/logical_create_table_plan.rs index c13185dc..45a8b6f6 100644 --- a/src/planner/logical_create_table_plan.rs +++ b/src/planner/logical_create_table_plan.rs @@ -1,8 +1,8 @@ +use crate::catalog::ColumnDesc; + pub struct LogicalCreateTablePlan { - // pub database_id: DatabaseId, - // pub schema_id: SchemaId, - // pub table_name: String, - // pub columns: Vec<(String, ColumnDesc)>, + pub table_name: String, + pub columns: Vec<(String, ColumnDesc)>, } // use sqlparser::ast::{ColumnDef, ColumnOption, Statement}; diff --git a/src/planner/operator/join.rs b/src/planner/operator/join.rs index 9a7ac592..78bdc158 100644 --- a/src/planner/operator/join.rs +++ b/src/planner/operator/join.rs @@ -1,9 +1,6 @@ use std::sync::Arc; -use crate::{ - expression::ScalarExpression, - planner::{logical_select_plan::LogicalSelectPlan}, -}; +use crate::{expression::ScalarExpression, planner::logical_select_plan::LogicalSelectPlan}; use super::Operator; diff --git a/src/planner/operator/scan.rs b/src/planner/operator/scan.rs index 78ee03ad..bffd9cea 100644 --- a/src/planner/operator/scan.rs +++ b/src/planner/operator/scan.rs @@ -1,11 +1,10 @@ use std::sync::Arc; +use crate::types::TableId; use crate::{ - catalog::ColumnRefId, - expression::ScalarExpression, + catalog::ColumnRefId, expression::ScalarExpression, planner::logical_select_plan::LogicalSelectPlan, }; -use crate::types::TableId; use super::{sort::SortField, Operator}; diff --git a/src/types/mod.rs b/src/types/mod.rs index aca39275..73543e3d 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,9 +1,9 @@ pub mod value; -use std::sync::atomic::AtomicU32; -use std::sync::atomic::Ordering::{Acquire, Release}; use integer_encoding::FixedInt; pub use sqlparser::ast::DataType as DataTypeKind; +use std::sync::atomic::AtomicU32; +use std::sync::atomic::Ordering::{Acquire, Release}; static ID_BUF: AtomicU32 = AtomicU32::new(0); @@ -45,13 +45,11 @@ impl DataTypeExt for DataTypeKind { } } -pub(crate) struct IdGenerator { } +pub(crate) struct IdGenerator {} impl IdGenerator { pub(crate) fn encode_to_raw() -> Vec { - ID_BUF - .load(Acquire) - .encode_fixed_vec() + ID_BUF.load(Acquire).encode_fixed_vec() } pub(crate) fn from_raw(buf: &[u8]) { @@ -70,11 +68,10 @@ impl IdGenerator { pub type TableId = u32; pub type ColumnId = u32; - #[cfg(test)] mod test { + use crate::types::{IdGenerator, ID_BUF}; use std::sync::atomic::Ordering::Release; - use crate::types::{ID_BUF, IdGenerator}; /// Tips: 由于IdGenerator为static全局性质生成的id,因此需要单独测试避免其他测试方法干扰 #[test] @@ -97,4 +94,4 @@ mod test { fn test_id_generator_reset() { ID_BUF.store(0, Release) } -} \ No newline at end of file +}