Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/binder/create.rs
Original file line number Diff line number Diff line change
@@ -1 +1,55 @@
use super::Binder;
use crate::binder::{lower_case_name, split_name};
use crate::catalog::{Column, ColumnDesc};
use crate::planner::logical_create_table_plan::LogicalCreateTablePlan;
use crate::planner::LogicalPlan;
use crate::types::ColumnId;
use anyhow::Result;
use sqlparser::ast::{ColumnDef, ObjectName};
use std::collections::HashSet;

impl Binder {
pub(super) fn bind_create_table(
&mut self,
name: ObjectName,
columns: &[ColumnDef],
) -> Result<LogicalCreateTablePlan> {
let name = lower_case_name(&name);

let (_, table_name) = split_name(&name)?;

let table = self
.context
.catalog
.get_table_by_name(table_name)
.ok_or_else(|| {
anyhow::Error::msg(format!("table {} not found", table_name.to_string()))
})?;

// check duplicated column names
let mut set = HashSet::new();
for col in columns.iter() {
if !set.insert(col.name.value.clone()) {
return Err(anyhow::Error::msg(format!(
"bind duplicated column {}",
col.name.value.clone()
)));
}
}

let mut columns: Vec<Column> = columns
.iter()
.enumerate()
.map(|(_, col)| Column::from(col))
.collect();

let plan = LogicalCreateTablePlan {
table_name: table_name.to_string(),
columns: columns
.into_iter()
.map(|col| (col.name.to_string(), col.desc.clone()))
.collect(),
};
Ok(plan)
}
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code patch appears to be an implementation of the bind_create_table method in a Binder struct. Here are some observations and suggestions for improvement:

  1. It seems that the code is missing necessary import statements at the beginning. Make sure to include the required imports, such as super, Binder, lower_case_name, split_name, etc.

  2. Consider adding proper documentation or comments to describe the purpose and functionality of the bind_create_table method. This will make it easier for other developers (including yourself) to understand the code in the future.

  3. It's generally a good practice to handle potential errors gracefully. For example, when retrieving a table by name (self.context.catalog.get_table_by_name(table_name)), consider using the Result type along with proper error handling instead of directly calling .ok_or_else() and returning an Error. This will provide more information about the specific error that occurred during table retrieval.

  4. The code checks for duplicated column names using a HashSet. This is a valid approach, but it would be helpful to provide a more specific error message indicating which column names are duplicated.

  5. When mapping the columns from ColumnDef to Column, the line let mut col = Column::from(col); seems unnecessary since the Column::from function can directly return a new instance without needing to mutate it.

  6. Check if there are any additional validations or constraints that need to be implemented. For example, you may want to verify the data types or enforce certain properties on the columns being created.

  7. Consider adding tests to cover different scenarios, such as creating a table with no columns, creating a table with duplicate columns, etc. This will ensure the correctness of the implementation and help catch any potential issues.

Remember that these suggestions are general guidelines, and your specific project requirements may vary.

34 changes: 27 additions & 7 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,12 @@ mod select;

use std::collections::HashMap;

use crate::{
catalog::CatalogRef,
expression::ScalarExpression,
planner::LogicalPlan,
};
use crate::{catalog::CatalogRef, expression::ScalarExpression, planner::LogicalPlan};

use anyhow::Result;
use sqlparser::ast::Statement;
use crate::catalog::DEFAULT_SCHEMA_NAME;
use crate::types::TableId;
use anyhow::Result;
use sqlparser::ast::{Ident, ObjectName, Statement};

pub struct BinderContext {
catalog: CatalogRef,
Expand Down Expand Up @@ -66,8 +63,31 @@ impl Binder {
let plan = self.bind_query(query)?;
LogicalPlan::Select(plan)
}
Statement::CreateTable { name, columns, .. } => {
let plan = self.bind_create_table(name.to_owned(), &columns)?;
LogicalPlan::CreateTable(plan)
}
_ => unimplemented!(),
};
Ok(plan)
}
}

/// Convert an object name into lower case
fn lower_case_name(name: &ObjectName) -> ObjectName {
ObjectName(
name.0
.iter()
.map(|ident| Ident::new(ident.value.to_lowercase()))
.collect(),
)
}

/// Split an object name into `(schema name, table name)`.
fn split_name(name: &ObjectName) -> Result<(&str, &str)> {
Ok(match name.0.as_slice() {
[table] => (DEFAULT_SCHEMA_NAME, &table.value),
[schema, table] => (&schema.value, &table.value),
_ => return Err(anyhow::anyhow!("Invalid table name: {:?}", name)),
})
}
5 changes: 3 additions & 2 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ use crate::{

use super::Binder;

use crate::catalog::{DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME};
use anyhow::Result;
use itertools::Itertools;
use sqlparser::ast::{
Expr, Ident, Join, JoinConstraint, JoinOperator, Offset, OrderByExpr, Query, Select,
SelectItem, SetExpr, TableFactor, TableWithJoins,
};
use crate::catalog::{DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME};

impl Binder {
pub(super) fn bind_query(&mut self, query: &Query) -> Result<LogicalSelectPlan> {
Expand Down Expand Up @@ -128,7 +128,8 @@ impl Binder {
.map(|ident| Ident::new(ident.value.to_lowercase()))
.collect_vec();

let (_database, _schema, mut table): (&str, &str, &str) = match obj_name.as_slice() {
let (_database, _schema, mut table): (&str, &str, &str) = match obj_name.as_slice()
{
[table] => (DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME, &table.value),
[schema, table] => (DEFAULT_DATABASE_NAME, &schema.value, &table.value),
[database, schema, table] => (&database.value, &schema.value, &table.value),
Expand Down
28 changes: 24 additions & 4 deletions src/catalog/column.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::types::{ColumnId, DataType, IdGenerator};
use sqlparser::ast::{ColumnDef, ColumnOption};

#[derive(Clone)]
pub struct Column {
Expand All @@ -8,10 +9,7 @@ pub struct Column {
}

impl Column {
pub(crate) fn new(
column_name: String,
column_desc: ColumnDesc,
) -> Column {
pub(crate) fn new(column_name: String, column_desc: ColumnDesc) -> Column {
Column {
id: IdGenerator::build(),
name: column_name,
Expand Down Expand Up @@ -63,6 +61,28 @@ impl ColumnDesc {
}
}

impl From<&ColumnDef> for Column {
fn from(cdef: &ColumnDef) -> Self {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This variable name definition is too easy to confuse!

let mut is_nullable = true;
let mut is_primary_ = false;
for opt in &cdef.options {
match opt.option {
ColumnOption::Null => is_nullable = true,
ColumnOption::NotNull => is_nullable = false,
ColumnOption::Unique { is_primary } => is_primary_ = is_primary,
_ => todo!("column options"),
}
}
Column::new(
cdef.name.value.clone(),
ColumnDesc::new(
DataType::new(cdef.data_type.clone(), is_nullable),
is_primary_,
),
)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
16 changes: 14 additions & 2 deletions src/catalog/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ pub struct Root {
impl Root {
#[allow(dead_code)]
pub(crate) fn new() -> Self {
Root { table_idxs: Default::default(), tables: Default::default() }
Root {
table_idxs: Default::default(),
tables: Default::default(),
}
}

pub(crate) fn get_table_id_by_name(&self, name: &str) -> Option<TableId> {
Expand All @@ -21,7 +24,16 @@ impl Root {
self.tables.get(&table_id)
}

pub(crate) fn add_table(&mut self, table_name: String, columns: Vec<Column>) -> Result<TableId, CatalogError> {
pub(crate) fn get_table_by_name(&self, name: &str) -> Option<&Table> {
let id = self.table_idxs.get(name)?;
self.tables.get(id)
}

pub(crate) fn add_table(
&mut self,
table_name: String,
columns: Vec<Column>,
) -> Result<TableId, CatalogError> {
if self.table_idxs.contains_key(&table_name) {
return Err(CatalogError::Duplicated("column", table_name));
}
Expand Down
27 changes: 8 additions & 19 deletions src/catalog/table.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::{BTreeMap, HashMap};
use itertools::Itertools;
use crate::catalog::{CatalogError, Column};
use crate::types::{ColumnId, IdGenerator, TableId};
use itertools::Itertools;
use std::collections::{BTreeMap, HashMap};

pub struct Table {
pub id: TableId,
Expand All @@ -25,18 +25,16 @@ impl Table {
}

pub(crate) fn get_all_columns(&self) -> Vec<(ColumnId, &Column)> {
self.columns.iter()
self.columns
.iter()
.map(|(col_id, col)| (*col_id, col))
.collect_vec()
}

/// Add a column to the table catalog.
pub(crate) fn add_column(&mut self, col_catalog: Column) -> Result<ColumnId, CatalogError> {
if self.column_idxs.contains_key(&col_catalog.name) {
return Err(CatalogError::Duplicated(
"column",
col_catalog.name.into(),
));
return Err(CatalogError::Duplicated("column", col_catalog.name.into()));
}

let col_id = col_catalog.id;
Expand Down Expand Up @@ -74,19 +72,10 @@ mod tests {
// | 1 | true |
// | 2 | false |
fn test_table_catalog() {
let col0 = Column::new(
"a".into(),
DataTypeKind::Int(None).not_null().to_column(),
);
let col1 = Column::new(
"b".into(),
DataTypeKind::Boolean.not_null().to_column()
);
let col0 = Column::new("a".into(), DataTypeKind::Int(None).not_null().to_column());
let col1 = Column::new("b".into(), DataTypeKind::Boolean.not_null().to_column());
let col_catalogs = vec![col0, col1];
let table_catalog = Table::new(
"test".to_string(),
col_catalogs
).unwrap();
let table_catalog = Table::new("test".to_string(), col_catalogs).unwrap();

assert_eq!(table_catalog.contains_column("a"), true);
assert_eq!(table_catalog.contains_column("b"), true);
Expand Down
8 changes: 4 additions & 4 deletions src/planner/logical_create_table_plan.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::catalog::ColumnDesc;

pub struct LogicalCreateTablePlan {
// pub database_id: DatabaseId,
// pub schema_id: SchemaId,
// pub table_name: String,
// pub columns: Vec<(String, ColumnDesc)>,
pub table_name: String,
pub columns: Vec<(String, ColumnDesc)>,
}

// use sqlparser::ast::{ColumnDef, ColumnOption, Statement};
Expand Down
5 changes: 1 addition & 4 deletions src/planner/operator/join.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use std::sync::Arc;

use crate::{
expression::ScalarExpression,
planner::{logical_select_plan::LogicalSelectPlan},
};
use crate::{expression::ScalarExpression, planner::logical_select_plan::LogicalSelectPlan};

use super::Operator;

Expand Down
5 changes: 2 additions & 3 deletions src/planner/operator/scan.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::sync::Arc;

use crate::types::TableId;
use crate::{
catalog::ColumnRefId,
expression::ScalarExpression,
catalog::ColumnRefId, expression::ScalarExpression,
planner::logical_select_plan::LogicalSelectPlan,
};
use crate::types::TableId;

use super::{sort::SortField, Operator};

Expand Down
15 changes: 6 additions & 9 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pub mod value;

use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::{Acquire, Release};
use integer_encoding::FixedInt;
pub use sqlparser::ast::DataType as DataTypeKind;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::{Acquire, Release};

static ID_BUF: AtomicU32 = AtomicU32::new(0);

Expand Down Expand Up @@ -45,13 +45,11 @@ impl DataTypeExt for DataTypeKind {
}
}

pub(crate) struct IdGenerator { }
pub(crate) struct IdGenerator {}

impl IdGenerator {
pub(crate) fn encode_to_raw() -> Vec<u8> {
ID_BUF
.load(Acquire)
.encode_fixed_vec()
ID_BUF.load(Acquire).encode_fixed_vec()
}

pub(crate) fn from_raw(buf: &[u8]) {
Expand All @@ -70,11 +68,10 @@ impl IdGenerator {
pub type TableId = u32;
pub type ColumnId = u32;


#[cfg(test)]
mod test {
use crate::types::{IdGenerator, ID_BUF};
use std::sync::atomic::Ordering::Release;
use crate::types::{ID_BUF, IdGenerator};

/// Tips: 由于IdGenerator为static全局性质生成的id,因此需要单独测试避免其他测试方法干扰
#[test]
Expand All @@ -97,4 +94,4 @@ mod test {
fn test_id_generator_reset() {
ID_BUF.store(0, Release)
}
}
}