From 30ba28d41fc518cda9596af85597fa1ff90039d6 Mon Sep 17 00:00:00 2001 From: Xwg Date: Thu, 13 Jul 2023 23:48:05 +0800 Subject: [PATCH] feat(type): use arrow type to support vectorization - change catalog name to avoid confusion. - remove some unused import but there are some possible uses that have not been removed. - memory storage is not complete. - remove gpt code reviewer. - change ci workflow. - add toolchain of rust --- .github/workflows/ci.yml | 51 +-- .github/workflows/cr.yml | 28 -- Cargo.toml | 4 + rust-toolchain | 1 + src/binder/create.rs | 37 +- src/binder/expr.rs | 2 +- src/binder/mod.rs | 31 +- src/binder/select.rs | 11 +- src/catalog/column.rs | 91 ++--- src/catalog/mod.rs | 11 +- src/catalog/root.rs | 37 +- src/catalog/table.rs | 49 +-- src/db.rs | 91 ++++- src/expression/agg.rs | 5 +- src/expression/mod.rs | 28 +- src/lib.rs | 3 + src/main.rs | 8 +- src/planner/logical_plan_builder.rs | 9 - src/planner/mod.rs | 1 + src/storage/memory.rs | 187 +++++++++ src/storage/mod.rs | 99 ++--- src/types/errors.rs | 11 + src/types/mod.rs | 357 ++++++++++++++-- src/types/value.rs | 612 +++++++++++++++++++++++++++- 24 files changed, 1406 insertions(+), 358 deletions(-) delete mode 100644 .github/workflows/cr.yml create mode 100644 rust-toolchain create mode 100644 src/storage/memory.rs create mode 100644 src/types/errors.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9ffb632e..0abd0a06 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,47 +8,40 @@ env: CARGO_TERM_COLOR: always jobs: - check: - runs-on: ubuntu-20.04 + fmt: + runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: nightly-2023-04-07 components: rustfmt, clippy - - name: Check code format - uses: actions-rs/cargo@v1 + - uses: actions/cache@v3 with: - command: fmt - args: --all -- --check + path: | + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - name: Check code format + run: cargo fmt --all -- --check - build: - runs-on: ubuntu-20.04 - steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: nightly-2023-04-07 - - uses: actions/checkout@v2 - - name: Build - uses: actions-rs/cargo@v1 - with: - command: build test: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: nightly-2023-04-07 - - uses: actions/checkout@v2 - - name: Test - uses: actions-rs/cargo@v1 + - uses: actions/cache@v3 with: - command: test - args: --release --no-fail-fast + path: | + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - uses: taiki-e/install-action@nextest + - name: Test + run: cargo nextest run --no-fail-fast --all-features diff --git a/.github/workflows/cr.yml b/.github/workflows/cr.yml deleted file mode 100644 index c7a80868..00000000 --- a/.github/workflows/cr.yml +++ /dev/null @@ -1,28 +0,0 @@ -name: Code Review - -permissions: - contents: read - pull-requests: write - -on: - pull_request: - types: [opened, reopened, synchronize] - -jobs: - test: - # if: ${{ contains(github.event.*.labels.*.name, 'gpt review') }} # Optional; to run only when a label is attached - runs-on: ubuntu-latest - steps: - - uses: anc95/ChatGPT-CodeReview@main - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - # Optional - LANGUAGE: Chinese - OPENAI_API_ENDPOINT: https://api.openai.com/v1 - MODEL: gpt-3.5-turbo - PROMPT: - top_p: 1 - temperature: 1 - max_tokens: 10000 - MAX_PATCH_LENGTH: 10000 # if the patch/diff length is large than MAX_PATCH_LENGTH, will be ignored and won't review. By default, with no MAX_PATCH_LENGTH set, there is also no limit for the patch/diff length. \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index c37c228f..7ec0aff4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,9 @@ serde = { version = "1", features = ["derive", "rc"] } serde_json = "1" async-trait = "0.1.68" integer-encoding = "3.0.4" +arrow = { version = "28", features = ["prettyprint", "simd"] } +strum_macros = "0.24" +ordered-float = "3.0" petgraph = "0.6.3" futures-async-stream = "0.2.6" async-channel = "1.8.0" @@ -38,6 +41,7 @@ async-backtrace = "0.2.6" futures = "0.3.25" futures-lite = "1.12.0" + [dev-dependencies] ctor = "0.2.0" env_logger = "0.10" diff --git a/rust-toolchain b/rust-toolchain new file mode 100644 index 00000000..ab84e227 --- /dev/null +++ b/rust-toolchain @@ -0,0 +1 @@ +nightly-2023-04-07 diff --git a/src/binder/create.rs b/src/binder/create.rs index ce77638a..b1c65bc3 100644 --- a/src/binder/create.rs +++ b/src/binder/create.rs @@ -1,12 +1,12 @@ +use std::collections::HashSet; + +use anyhow::Result; +use sqlparser::ast::{ColumnDef, ObjectName}; + use super::Binder; use crate::binder::{lower_case_name, split_name}; -use crate::catalog::{Column, ColumnDesc}; +use crate::catalog::ColumnCatalog; use crate::planner::logical_create_table_plan::LogicalCreateTablePlan; -use crate::planner::LogicalPlan; -use crate::types::{ColumnId, TableId}; -use anyhow::Result; -use sqlparser::ast::{ColumnDef, ObjectName}; -use std::collections::HashSet; impl Binder { pub(crate) fn bind_create_table( @@ -29,10 +29,10 @@ impl Binder { } } - let mut columns: Vec = columns + let columns: Vec = columns .iter() .enumerate() - .map(|(_, col)| Column::from(col)) + .map(|(_, col)| ColumnCatalog::from(col.clone())) .collect(); let plan = LogicalCreateTablePlan { @@ -48,36 +48,31 @@ impl Binder { #[cfg(test)] mod tests { + use sqlparser::ast::CharacterLength; + use super::*; use crate::binder::BinderContext; - use crate::catalog::Root; - use crate::types::{DataTypeExt, DataTypeKind}; - use sqlparser::ast::CharacterLength; - use std::sync::Arc; + use crate::catalog::{ColumnDesc, RootCatalog}; + use crate::planner::LogicalPlan; + use crate::types::LogicalType; #[test] fn test_create_bind() { let sql = "create table t1 (id int , name varchar(10))"; - let mut binder = Binder::new(BinderContext::new(Root::new())); + let binder = Binder::new(BinderContext::new(RootCatalog::new())); let stmt = crate::parser::parse_sql(sql).unwrap(); let plan1 = binder.bind(&stmt[0]).unwrap(); - let character_length = CharacterLength { - length: 10, - unit: None, - }; let plan2 = LogicalPlan::CreateTable(LogicalCreateTablePlan { table_name: "t1".to_string(), columns: vec![ ( "id".to_string(), - DataTypeKind::Int(None).nullable().to_column(), + ColumnDesc::new(LogicalType::Integer, false), ), ( "name".to_string(), - DataTypeKind::Varchar(Option::from(character_length)) - .nullable() - .to_column(), + ColumnDesc::new(LogicalType::Varchar, false), ), ], }); diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 335734ef..7df3d422 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -1,8 +1,8 @@ +use anyhow::Result; use sqlparser::ast::Expr; use super::Binder; use crate::expression::ScalarExpression; -use anyhow::Result; impl Binder { pub(crate) fn bind_expr(&mut self, expr: &Expr) -> Result { diff --git a/src/binder/mod.rs b/src/binder/mod.rs index ab60fef4..b2a6f8f9 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -5,15 +5,16 @@ mod select; use std::collections::HashMap; -use crate::{catalog::CatalogRef, expression::ScalarExpression, planner::LogicalPlan}; - -use crate::catalog::{Root, DEFAULT_SCHEMA_NAME}; -use crate::types::TableId; use anyhow::Result; use sqlparser::ast::{Ident, ObjectName, Statement}; + +use crate::catalog::{RootCatalog, DEFAULT_SCHEMA_NAME}; +use crate::expression::ScalarExpression; +use crate::planner::LogicalPlan; +use crate::types::TableId; #[derive(Clone)] pub struct BinderContext { - catalog: Root, + catalog: RootCatalog, bind_table: HashMap, aliases: HashMap, group_by_exprs: Vec, @@ -22,7 +23,7 @@ pub struct BinderContext { } impl BinderContext { - pub fn new(catalog: Root) -> Self { + pub fn new(catalog: RootCatalog) -> Self { BinderContext { catalog, bind_table: Default::default(), @@ -91,3 +92,21 @@ fn split_name(name: &ObjectName) -> Result<(&str, &str)> { _ => return Err(anyhow::anyhow!("Invalid table name: {:?}", name)), }) } + +#[derive(thiserror::Error, Debug)] +pub enum BindError { + #[error("unsupported statement {0}")] + UnsupportedStmt(String), + #[error("invalid table {0}")] + InvalidTable(String), + #[error("invalid table name: {0:?}")] + InvalidTableName(Vec), + #[error("invalid column {0}")] + InvalidColumn(String), + #[error("ambiguous column {0}")] + AmbiguousColumn(String), + #[error("binary operator types mismatch: {0} != {1}")] + BinaryOpTypeMismatch(String, String), + #[error("subquery in FROM must have an alias")] + SubqueryMustHaveAlias, +} diff --git a/src/binder/select.rs b/src/binder/select.rs index dd6453db..5e7e5bed 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -166,8 +166,7 @@ impl Binder { /// /// - Qualified name, e.g. `SELECT t.a FROM t` /// - Qualified name with wildcard, e.g. `SELECT t.* FROM t,t1` - /// - Scalar expression or aggregate expression, e.g. `SELECT COUNT(*) + 1 - /// AS count FROM t` + /// - Scalar expression or aggregate expression, e.g. `SELECT COUNT(*) + 1 AS count FROM t` /// fn normalize_select_item(&mut self, items: &[SelectItem]) -> Result> { let mut select_items = vec![]; @@ -293,8 +292,8 @@ impl Binder { let expr = self.bind_expr(expr)?; match expr { ScalarExpression::Constant(dv) => match dv { - DataValue::Int32(v) if v > 0 => limit = v as usize, - DataValue::Int64(v) if v > 0 => limit = v as usize, + DataValue::Int32(Some(v)) if v > 0 => limit = v as usize, + DataValue::Int64(Some(v)) if v > 0 => limit = v as usize, _ => return Err(anyhow::Error::msg("invalid limit expression.".to_owned())), }, _ => return Err(anyhow::Error::msg("invalid limit expression.".to_owned())), @@ -305,8 +304,8 @@ impl Binder { let expr = self.bind_expr(&expr.value)?; match expr { ScalarExpression::Constant(dv) => match dv { - DataValue::Int32(v) if v > 0 => offset = v as usize, - DataValue::Int64(v) if v > 0 => offset = v as usize, + DataValue::Int32(Some(v)) if v > 0 => offset = v as usize, + DataValue::Int64(Some(v)) if v > 0 => offset = v as usize, _ => return Err(anyhow::Error::msg("invalid limit expression.".to_owned())), }, _ => return Err(anyhow::Error::msg("invalid offset expression.".to_owned())), diff --git a/src/catalog/column.rs b/src/catalog/column.rs index ae3e0a58..91badb23 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -1,23 +1,25 @@ -use crate::types::{ColumnId, DataType, IdGenerator}; -use sqlparser::ast::{ColumnDef, ColumnOption}; +use arrow::datatypes::{DataType, Field}; +use sqlparser::ast::ColumnDef; + +use crate::types::{ColumnId, IdGenerator, LogicalType}; #[derive(Debug, Clone)] -pub struct Column { +pub struct ColumnCatalog { pub id: ColumnId, pub name: String, pub desc: ColumnDesc, } -impl Column { - pub(crate) fn new(column_name: String, column_desc: ColumnDesc) -> Column { - Column { +impl ColumnCatalog { + pub(crate) fn new(column_name: String, column_desc: ColumnDesc) -> ColumnCatalog { + ColumnCatalog { id: IdGenerator::build(), name: column_name, desc: column_desc, } } - pub(crate) fn datatype(&self) -> &DataType { + pub(crate) fn datatype(&self) -> &LogicalType { &self.desc.column_datatype } @@ -28,28 +30,35 @@ impl Column { pub fn desc(&self) -> &ColumnDesc { &self.desc } -} -impl DataType { - #[inline] - pub const fn to_column(self) -> ColumnDesc { - ColumnDesc::new(self, false) + pub fn to_field(&self) -> Field { + Field::new( + &*self.name.clone(), + DataType::from(self.desc.column_datatype.clone()), + self.desc.is_primary(), + ) } - #[inline] - pub const fn to_column_primary_key(self) -> ColumnDesc { - ColumnDesc::new(self, true) +} + +impl From for ColumnCatalog { + fn from(column_def: ColumnDef) -> Self { + let column_name = column_def.name.to_string(); + let column_datatype = LogicalType::try_from(column_def.data_type).unwrap(); + let is_primary = false; + let column_desc = ColumnDesc::new(column_datatype, is_primary); + ColumnCatalog::new(column_name, column_desc) } } /// The descriptor of a column. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ColumnDesc { - column_datatype: DataType, - is_primary: bool, + pub(crate) column_datatype: LogicalType, + pub(crate) is_primary: bool, } impl ColumnDesc { - pub(crate) const fn new(column_datatype: DataType, is_primary: bool) -> ColumnDesc { + pub(crate) const fn new(column_datatype: LogicalType, is_primary: bool) -> ColumnDesc { ColumnDesc { column_datatype, is_primary, @@ -60,51 +69,7 @@ impl ColumnDesc { self.is_primary } - pub(crate) fn is_nullable(&self) -> bool { - self.column_datatype.is_nullable() - } - - pub(crate) fn get_datatype(&self) -> DataType { + pub(crate) fn get_datatype(&self) -> LogicalType { self.column_datatype.clone() } } - -impl From<&ColumnDef> for Column { - fn from(cdef: &ColumnDef) -> Self { - let mut is_nullable = true; - let mut is_primary_ = false; - for opt in &cdef.options { - match opt.option { - ColumnOption::Null => is_nullable = true, - ColumnOption::NotNull => is_nullable = false, - ColumnOption::Unique { is_primary } => is_primary_ = is_primary, - _ => todo!("column options"), - } - } - Column::new( - cdef.name.value.clone(), - ColumnDesc::new( - DataType::new(cdef.data_type.clone(), is_nullable), - is_primary_, - ), - ) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::types::{DataTypeExt, DataTypeKind}; - - #[test] - fn test_column_catalog() { - let mut col_catalog = Column::new( - "test".to_string(), - DataTypeKind::Int(None).not_null().to_column(), - ); - - assert_eq!(col_catalog.desc.is_primary(), false); - assert_eq!(col_catalog.desc.is_nullable(), false); - assert_eq!(col_catalog.name, "test"); - } -} diff --git a/src/catalog/mod.rs b/src/catalog/mod.rs index 26ef971a..b4d0361c 100644 --- a/src/catalog/mod.rs +++ b/src/catalog/mod.rs @@ -1,15 +1,13 @@ // Module: catalog +use std::sync::Arc; + pub(crate) use self::column::*; pub(crate) use self::root::*; pub(crate) use self::table::*; - use crate::types::{ColumnId, TableId}; -use std::sync::Arc; /// The type of catalog reference. -pub type CatalogRef = Arc; -pub(crate) type TableRef = Arc; -pub(crate) type ColumnRef = Arc; +pub type CatalogRef = Arc; pub(crate) static DEFAULT_DATABASE_NAME: &str = "kipsql"; pub(crate) static DEFAULT_SCHEMA_NAME: &str = "kipsql"; @@ -45,8 +43,7 @@ impl TableRefId { } } -/// The error type of catalog operations. -#[derive(thiserror::Error, Debug, PartialEq, Eq)] +#[derive(thiserror::Error, Debug)] pub enum CatalogError { #[error("{0} not found: {1}")] NotFound(&'static str, String), diff --git a/src/catalog/root.rs b/src/catalog/root.rs index 64f627f6..2fc3c882 100644 --- a/src/catalog/root.rs +++ b/src/catalog/root.rs @@ -1,23 +1,24 @@ -use crate::catalog::{CatalogError, Column, Table}; -use crate::types::TableId; use std::collections::BTreeMap; +use crate::catalog::{CatalogError, ColumnCatalog, TableCatalog}; +use crate::types::TableId; + #[derive(Debug, Clone)] -pub struct Root { +pub struct RootCatalog { table_idxs: BTreeMap, - tables: BTreeMap, + tables: BTreeMap, } -impl Default for Root { +impl Default for RootCatalog { fn default() -> Self { Self::new() } } -impl Root { +impl RootCatalog { #[allow(dead_code)] pub fn new() -> Self { - Root { + RootCatalog { table_idxs: Default::default(), tables: Default::default(), } @@ -27,11 +28,11 @@ impl Root { self.table_idxs.get(name).cloned() } - pub(crate) fn get_table(&self, table_id: TableId) -> Option<&Table> { + pub(crate) fn get_table(&self, table_id: TableId) -> Option<&TableCatalog> { self.tables.get(&table_id) } - pub(crate) fn get_table_by_name(&self, name: &str) -> Option<&Table> { + pub(crate) fn get_table_by_name(&self, name: &str) -> Option<&TableCatalog> { let id = self.table_idxs.get(name)?; self.tables.get(id) } @@ -39,12 +40,12 @@ impl Root { pub(crate) fn add_table( &mut self, table_name: String, - columns: Vec, + columns: Vec, ) -> Result { if self.table_idxs.contains_key(&table_name) { return Err(CatalogError::Duplicated("column", table_name)); } - let table = Table::new(table_name.to_owned(), columns)?; + let table = TableCatalog::new(table_name.to_owned(), columns)?; let table_id = table.id; self.table_idxs.insert(table_name, table_id); @@ -57,20 +58,20 @@ impl Root { #[cfg(test)] mod tests { use super::*; - use crate::catalog::Column; - use crate::types::{DataTypeExt, DataTypeKind}; + use crate::catalog::{ColumnCatalog, ColumnDesc}; + use crate::types::LogicalType; #[test] fn test_root_catalog() { - let mut root_catalog = Root::new(); + let mut root_catalog = RootCatalog::new(); - let col0 = Column::new( + let col0 = ColumnCatalog::new( "a".to_string(), - DataTypeKind::Int(None).not_null().to_column(), + ColumnDesc::new(LogicalType::Integer, false), ); - let col1 = Column::new( + let col1 = ColumnCatalog::new( "b".to_string(), - DataTypeKind::Boolean.not_null().to_column(), + ColumnDesc::new(LogicalType::Boolean, false), ); let col_catalogs = vec![col0, col1]; diff --git a/src/catalog/table.rs b/src/catalog/table.rs index 93fc4eee..a7224efd 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -1,18 +1,20 @@ -use crate::catalog::{CatalogError, Column}; -use crate::types::{ColumnId, IdGenerator, TableId}; -use itertools::Itertools; use std::collections::{BTreeMap, HashMap}; + +use itertools::Itertools; + +use crate::catalog::{CatalogError, ColumnCatalog}; +use crate::types::{ColumnId, IdGenerator, TableId}; #[derive(Debug, Clone)] -pub struct Table { +pub struct TableCatalog { pub id: TableId, pub name: String, /// Mapping from column names to column ids column_idxs: HashMap, - columns: BTreeMap, + pub(crate) columns: BTreeMap, } -impl Table { - pub(crate) fn get_column_by_id(&self, id: ColumnId) -> Option<&Column> { +impl TableCatalog { + pub(crate) fn get_column_by_id(&self, id: ColumnId) -> Option<&ColumnCatalog> { self.columns.get(&id) } @@ -24,7 +26,7 @@ impl Table { self.column_idxs.contains_key(name) } - pub(crate) fn get_all_columns(&self) -> Vec<(ColumnId, &Column)> { + pub(crate) fn get_all_columns(&self) -> Vec<(ColumnId, &ColumnCatalog)> { self.columns .iter() .map(|(col_id, col)| (*col_id, col)) @@ -32,7 +34,10 @@ impl Table { } /// Add a column to the table catalog. - pub(crate) fn add_column(&mut self, col_catalog: Column) -> Result { + pub(crate) fn add_column( + &mut self, + col_catalog: ColumnCatalog, + ) -> Result { if self.column_idxs.contains_key(&col_catalog.name) { return Err(CatalogError::Duplicated("column", col_catalog.name.into())); } @@ -45,8 +50,11 @@ impl Table { Ok(col_id) } - pub(crate) fn new(table_name: String, columns: Vec) -> Result { - let mut table_catalog = Table { + pub(crate) fn new( + table_name: String, + columns: Vec, + ) -> Result { + let mut table_catalog = TableCatalog { id: IdGenerator::build(), name: table_name, column_idxs: HashMap::new(), @@ -64,7 +72,8 @@ impl Table { #[cfg(test)] mod tests { use super::*; - use crate::types::{DataType, DataTypeExt, DataTypeKind}; + use crate::catalog::ColumnDesc; + use crate::types::LogicalType; #[test] // | a (Int32) | b (Bool) | @@ -72,10 +81,10 @@ mod tests { // | 1 | true | // | 2 | false | fn test_table_catalog() { - let col0 = Column::new("a".into(), DataTypeKind::Int(None).not_null().to_column()); - let col1 = Column::new("b".into(), DataTypeKind::Boolean.not_null().to_column()); + let col0 = ColumnCatalog::new("a".into(), ColumnDesc::new(LogicalType::Integer, false)); + let col1 = ColumnCatalog::new("b".into(), ColumnDesc::new(LogicalType::Boolean, false)); let col_catalogs = vec![col0, col1]; - let table_catalog = Table::new("test".to_string(), col_catalogs).unwrap(); + let table_catalog = TableCatalog::new("test".to_string(), col_catalogs).unwrap(); assert_eq!(table_catalog.contains_column("a"), true); assert_eq!(table_catalog.contains_column("b"), true); @@ -87,16 +96,10 @@ mod tests { let column_catalog = table_catalog.get_column_by_id(col_a_id).unwrap(); assert_eq!(column_catalog.name, "a"); - assert_eq!( - column_catalog.datatype(), - &DataType::new(DataTypeKind::Int(None), false) - ); + assert_eq!(*column_catalog.datatype(), LogicalType::Integer,); let column_catalog = table_catalog.get_column_by_id(col_b_id).unwrap(); assert_eq!(column_catalog.name, "b"); - assert_eq!( - column_catalog.datatype(), - &DataType::new(DataTypeKind::Boolean, false) - ); + assert_eq!(*column_catalog.datatype(), LogicalType::Boolean,); } } diff --git a/src/db.rs b/src/db.rs index 32279638..ea84457f 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,26 +1,33 @@ -use crate::binder::{Binder, BinderContext}; -use crate::catalog::{CatalogRef, Column, Root}; -use crate::parser::parse_sql; -use crate::planner::logical_create_table_plan::LogicalCreateTablePlan; -use crate::planner::LogicalPlan; +use std::sync::Arc; -use crate::storage::{InMemoryStorage, Storage}; use anyhow::Result; +use arrow::datatypes::Schema; +use arrow::error::ArrowError; +use arrow::record_batch::RecordBatch; +use sqlparser::parser::ParserError; + +use crate::binder::{BindError, Binder, BinderContext}; +use crate::catalog::ColumnCatalog; +use crate::parser::parse_sql; +use crate::planner::LogicalPlan; +use crate::storage::memory::InMemoryStorage; +use crate::storage::{Storage, StorageError}; +use crate::types::IdGenerator; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct Database { pub storage: InMemoryStorage, } impl Default for Database { fn default() -> Self { - Self::new() + Self::new_on_mem() } } impl Database { /// Create a new Database instance. - pub fn new() -> Self { + pub fn new_on_mem() -> Self { let storage = InMemoryStorage::new(); Database { storage } } @@ -30,8 +37,9 @@ impl Database { // parse let stmts = parse_sql(sql)?; // bind - let catalog = self.storage.catalog(); - let mut binder = Binder::new(BinderContext::new(catalog.clone())); + let catalog = self.storage.get_catalog(); + + let binder = Binder::new(BinderContext::new(catalog.clone())); /// Build a logical plan. /// @@ -41,28 +49,67 @@ impl Database { /// Limit(1) /// Project(a,b) let logical_plan = binder.bind(&stmts[0])?; + println!("logic plan {:?}", logical_plan); - println!("{:?}", logical_plan); + // let physical_planner = PhysicalPlaner::default(); + // let executor_builder = ExecutorBuilder::new(self.env.clone()); - //let physical_planner = PhysicalPlaner::default(); - //let executor_builder = ExecutorBuilder::new(self.env.clone()); - - //let physical_plan = physical_planner.plan(logical_plan)?; - //let executor = executor_builder.build(physical_plan)?; - //futures::executor::block_on(executor).unwrap(); + // let physical_plan = physical_planner.plan(logical_plan)?; + // let executor = executor_builder.build(physical_plan)?; + // futures::executor::block_on(executor).unwrap(); /// THE FOLLOWING CODE IS FOR TESTING ONLY /// THE FINAL CODE WILL BE IN executor MODULE if let LogicalPlan::CreateTable(plan) = logical_plan { - let mut colums = Vec::new(); + let mut columns = Vec::new(); plan.columns.iter().for_each(|c| { - colums.push(Column::new(c.0.clone(), c.1.clone())); + columns.push(ColumnCatalog::new(c.0.clone(), c.1.clone())); + }); + let table_name = plan.table_name.clone(); + // columns->batch record + let mut data = Vec::new(); + + columns.iter().for_each(|c| { + let batch = RecordBatch::new_empty(Arc::new(Schema::new(vec![c.to_field()]))); + data.push(batch); }); - let mut table_name = plan.table_name.clone(); + self.storage - .create_table(&table_name.to_string(), &colums.clone())?; + .create_table(IdGenerator::build(), table_name.as_str(), data)?; } Ok(()) } } + +#[derive(thiserror::Error, Debug)] +pub enum DatabaseError { + #[error("parse error: {0}")] + Parse( + #[source] + #[from] + ParserError, + ), + #[error("bind error: {0}")] + Bind( + #[source] + #[from] + BindError, + ), + #[error("Storage error: {0}")] + StorageError( + #[source] + #[from] + #[backtrace] + StorageError, + ), + #[error("Arrow error: {0}")] + ArrowError( + #[source] + #[from] + #[backtrace] + ArrowError, + ), + #[error("Internal error: {0}")] + InternalError(String), +} diff --git a/src/expression/agg.rs b/src/expression/agg.rs index 9f9b7a6e..4a57fa6c 100644 --- a/src/expression/agg.rs +++ b/src/expression/agg.rs @@ -1,6 +1,5 @@ -use crate::types::DataType; - use super::ScalarExpression; +use crate::types::LogicalType; #[derive(Debug, Clone, PartialEq)] pub enum AggKind { @@ -15,6 +14,6 @@ pub enum AggKind { pub struct AggCall { pub kind: AggKind, pub args: Vec, - pub return_type: DataType, + pub return_type: LogicalType, // TODO: add distinct keyword } diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 4b8766f8..cf258a75 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -1,13 +1,11 @@ use std::fmt::Display; -use crate::{ - catalog::{ColumnDesc, ColumnRefId}, - types::{value::DataValue, DataType}, -}; use sqlparser::ast::{BinaryOperator as SqlBinaryOperator, UnaryOperator as SqlUnaryOperator}; use self::agg::AggKind; -pub use sqlparser::ast::DataType as DataTypeKind; +use crate::catalog::{ColumnDesc, ColumnRefId}; +use crate::types::value::DataValue; +use crate::types::LogicalType; pub mod agg; mod evaluator; @@ -26,7 +24,7 @@ pub enum ScalarExpression { }, InputRef { index: usize, - ty: DataType, + ty: LogicalType, }, Alias { expr: Box, @@ -34,7 +32,7 @@ pub enum ScalarExpression { }, TypeCast { expr: Box, - ty: DataType, + ty: LogicalType, is_try: bool, }, IsNull { @@ -43,32 +41,32 @@ pub enum ScalarExpression { Unary { op: UnaryOperator, expr: Box, - ty: Option, + ty: LogicalType, }, Binary { op: BinaryOperator, left_expr: Box, right_expr: Box, - ty: Option, + ty: LogicalType, }, AggCall { kind: AggKind, args: Vec, - ty: DataType, + ty: LogicalType, }, } impl ScalarExpression { - pub fn return_type(&self) -> Option { + pub fn return_type(&self) -> Option { match self { - Self::Constant(v) => v.data_type(), + Self::Constant(v) => Some(v.get_logic_type().clone()), Self::ColumnRef { desc, .. } => Some(desc.get_datatype().clone()), Self::Binary { ty: return_type, .. - } => return_type.clone(), + } => Some(return_type.clone()), Self::Unary { ty: return_type, .. - } => return_type.clone(), + } => Some(return_type.clone()), Self::TypeCast { ty: return_type, .. } => Some(return_type.clone()), @@ -78,7 +76,7 @@ impl ScalarExpression { Self::InputRef { ty: return_type, .. } => Some(return_type.clone()), - Self::IsNull { .. } => Some(DataType::new(DataTypeKind::Boolean, false)), + Self::IsNull { .. } => Some(LogicalType::Boolean), Self::Alias { expr, .. } => expr.return_type(), } } diff --git a/src/lib.rs b/src/lib.rs index e35dc419..69963779 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,6 @@ +#![feature(error_generic_member_access)] +#![feature(provide_any)] +#![allow(unused_doc_comments)] // #![deny( // // The following are allowed by default lints according to // // https://doc.rust-lang.org/rustc/lints/listing/allowed-by-default.html diff --git a/src/main.rs b/src/main.rs index ad9e7fdc..7997477b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,16 @@ -use kip_sql::db::Database; use std::io; +use kip_sql::db::Database; +use kip_sql::storage::Storage; + #[tokio::main] async fn main() -> Result<(), Box> { println!(":) Welcome to the KIPSQL, Please input sql."); - let mut db = Database::new(); + let mut db = Database::new_on_mem(); loop { - println!("storage catalog{:?}", db.storage.catalog); println!("> "); + println!("RootCatalog: {:?}", db.storage.get_catalog()); let mut input = String::new(); io::stdin().read_line(&mut input)?; let ret = db.run(&input); diff --git a/src/planner/logical_plan_builder.rs b/src/planner/logical_plan_builder.rs index 2cd915b4..38395930 100644 --- a/src/planner/logical_plan_builder.rs +++ b/src/planner/logical_plan_builder.rs @@ -1,12 +1,3 @@ -use crate::binder::{Binder, BinderContext}; -use crate::catalog::Root; -use crate::parser; -use anyhow::Result; -use std::sync::Arc; - -use crate::planner::logical_select_plan::LogicalSelectPlan; -use crate::planner::LogicalPlan; - #[derive(Clone)] pub struct PlanBuilder {} diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 17c9420d..a2f4dbd2 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -13,3 +13,4 @@ pub enum LogicalPlan { Select(LogicalSelectPlan), CreateTable(LogicalCreateTablePlan), } +pub enum LogicalPlanError {} diff --git a/src/storage/memory.rs b/src/storage/memory.rs new file mode 100644 index 00000000..83a87a55 --- /dev/null +++ b/src/storage/memory.rs @@ -0,0 +1,187 @@ +use std::collections::HashMap; +use std::sync::Mutex; + +use arrow::record_batch::RecordBatch; + +use crate::catalog::{ColumnCatalog, ColumnDesc, RootCatalog}; +use crate::storage::{Bounds, Projections, Storage, StorageError, Table, Transaction}; +use crate::types::{LogicalType, TableId}; + +#[derive(Debug)] +pub struct InMemoryStorage { + catalog: Mutex, + tables: Mutex>, +} + +impl Default for InMemoryStorage { + fn default() -> Self { + Self::new() + } +} + +impl InMemoryStorage { + pub fn new() -> Self { + InMemoryStorage { + catalog: Mutex::new(RootCatalog::default()), + tables: Mutex::new(HashMap::new()), + } + } +} + +impl Storage for InMemoryStorage { + type TableType = InMemoryTable; + + fn create_table( + &mut self, + id: TableId, + table_name: &str, + data: Vec, + ) -> Result<(), StorageError> { + let table = InMemoryTable::new(id.clone(), table_name, data)?; + self.catalog + .lock() + .unwrap() + .add_table(table.table_name.clone(), table.columns_vec.clone()) + .unwrap(); + self.tables.lock().unwrap().insert(id, table); + Ok(()) + } + + fn get_table(&self, id: TableId) -> Result { + self.tables + .lock() + .unwrap() + .get(&id) + .cloned() + .ok_or(StorageError::TableNotFound(id)) + } + + fn get_catalog(&self) -> RootCatalog { + self.catalog.lock().unwrap().clone() + } + + fn show_tables(&self) -> Result { + todo!() + } +} + +#[derive(Debug, Clone)] +pub struct InMemoryTable { + table_id: TableId, + table_name: String, + data: Vec, + columns_vec: Vec, +} + +impl InMemoryTable { + pub fn new(id: TableId, name: &str, data: Vec) -> Result { + let columns = Self::infer_catalog(data.first().cloned()); + Ok(Self { + table_id: id, + table_name: name.to_string(), + data, + columns_vec: columns, + }) + } + + fn infer_catalog(batch: Option) -> Vec { + let mut columns = Vec::new(); + if let Some(batch) = batch { + for f in batch.schema().fields().iter() { + let field_name = f.name().to_string(); + let column_dec = + ColumnDesc::new(LogicalType::try_from(f.data_type()).unwrap(), false); + let column_catalog = ColumnCatalog::new(field_name, column_dec); + columns.push(column_catalog) + } + } + columns + } +} + +impl Table for InMemoryTable { + type TransactionType = InMemoryTransaction; + + fn read( + &self, + _bounds: Bounds, + _projection: Projections, + ) -> Result { + InMemoryTransaction::start(self) + } +} + +pub struct InMemoryTransaction { + batch_cursor: usize, + data: Vec, +} + +impl InMemoryTransaction { + pub fn start(table: &InMemoryTable) -> Result { + Ok(Self { + batch_cursor: 0, + data: table.data.clone(), + }) + } +} + +impl Transaction for InMemoryTransaction { + fn next_batch(&mut self) -> Result, StorageError> { + self.data + .get(self.batch_cursor) + .map(|batch| { + self.batch_cursor += 1; + Ok(batch.clone()) + }) + .transpose() + } +} + +#[cfg(test)] +mod storage_test { + use std::sync::Arc; + + use crate::types::IdGenerator; + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + + use super::*; + + fn build_record_batch() -> Result, StorageError> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![4, 5, 6])), + ], + )?; + Ok(vec![batch]) + } + + #[test] + fn test_in_memory_storage_works_with_data() -> Result<(), StorageError> { + let id = IdGenerator::build(); + let mut storage = InMemoryStorage::new(); + storage.create_table(id.clone(), "test", build_record_batch()?)?; + + let catalog = storage.get_catalog(); + println!("{:?}", catalog); + let table_catalog = catalog.get_table_by_name("test"); + assert!(table_catalog.is_some()); + assert!(table_catalog.unwrap().get_column_id_by_name("a").is_some()); + + let table = storage.get_table(id)?; + let mut tx = table.read(None, None)?; + let batch = tx.next_batch()?; + println!("{:?}", batch); + assert!(batch.is_some()); + assert_eq!(batch.unwrap().num_rows(), 3); + + Ok(()) + } +} diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 78e582da..2d3071cf 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1,62 +1,63 @@ -use crate::catalog::{CatalogRef, Column, Root, TableRefId}; -use crate::types::ColumnId; -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +pub(crate) mod memory; -#[derive(thiserror::Error, Debug, PartialEq)] -pub enum StorageError { - #[error("failed to read table")] - ReadTableError, - #[error("failed to write table")] - WriteTableError, - #[error("{0}({1}) not found")] - NotFound(&'static str, u32), - #[error("duplicated {0}: {1}")] - Duplicated(&'static str, String), - #[error("invalid column id: {0}")] - InvalidColumn(ColumnId), +use std::io; + +use arrow::error::ArrowError; +use arrow::record_batch::RecordBatch; + +use crate::catalog::RootCatalog; +use crate::storage::memory::InMemoryStorage; +use crate::types::TableId; + +#[derive(Debug)] +pub enum StorageImpl { + InMemoryStorage(InMemoryStorage), } pub trait Storage: Sync + Send { - fn create_table(&mut self, table_name: &str, columns: &Vec) - -> Result<(), StorageError>; -} + type TableType: Table; -pub type StorageRef = Arc; + fn create_table( + &mut self, + id: TableId, + table_name: &str, + columns: Vec, + ) -> Result<(), StorageError>; + fn get_table(&self, id: TableId) -> Result; + fn get_catalog(&self) -> RootCatalog; + fn show_tables(&self) -> Result; +} -#[derive(Debug, Clone)] -pub struct InMemoryStorage { - pub catalog: Root, +/// Optional bounds of the reader, of the form (offset, limit). +type Bounds = Option<(usize, usize)>; +type Projections = Option>; + +pub trait Table: Sync + Send + Clone + 'static { + type TransactionType: Transaction; + + /// The bounds is applied to the whole data batches, not per batch. + /// + /// The projections is column indices. + fn read( + &self, + bounds: Bounds, + projection: Projections, + ) -> Result; } -impl Default for InMemoryStorage { - fn default() -> Self { - Self::new() - } +// currently we use a transaction to hold csv reader +pub trait Transaction: Sync + Send + 'static { + fn next_batch(&mut self) -> Result, StorageError>; } -impl InMemoryStorage { - pub fn new() -> Self { - InMemoryStorage { - catalog: Root::new(), - } - } +#[derive(thiserror::Error, Debug)] +pub enum StorageError { + #[error("arrow error")] + ArrowError(#[from] ArrowError), - pub fn catalog(&self) -> &Root { - &self.catalog - } -} + #[error("io error")] + IoError(#[from] io::Error), -impl Storage for InMemoryStorage { - fn create_table( - &mut self, - table_name: &str, - column_descs: &Vec, - ) -> Result<(), StorageError> { - let table_id = self - .catalog - .add_table(table_name.into(), column_descs.to_vec()) - .map_err(|_| StorageError::Duplicated("table", table_name.into()))?; - Ok(()) - } + #[error("table not found: {0}")] + TableNotFound(TableId), } diff --git a/src/types/errors.rs b/src/types/errors.rs new file mode 100644 index 00000000..f3842b44 --- /dev/null +++ b/src/types/errors.rs @@ -0,0 +1,11 @@ +#[derive(thiserror::Error, Debug)] +pub enum TypeError { + #[error("invalid logical type")] + InvalidLogicalType, + #[error("not implemented arrow datatype: {0}")] + NotImplementedArrowDataType(String), + #[error("not implemented sqlparser datatype: {0}")] + NotImplementedSqlparserDataType(String), + #[error("internal error: {0}")] + InternalError(String), +} diff --git a/src/types/mod.rs b/src/types/mod.rs index 73543e3d..42cbf3e1 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,49 +1,15 @@ +mod errors; pub mod value; - -use integer_encoding::FixedInt; -pub use sqlparser::ast::DataType as DataTypeKind; use std::sync::atomic::AtomicU32; use std::sync::atomic::Ordering::{Acquire, Release}; -static ID_BUF: AtomicU32 = AtomicU32::new(0); - -/// Inner data type -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct DataType { - kind: DataTypeKind, - nullable: bool, -} - -impl DataType { - #[inline] - pub const fn new(kind: DataTypeKind, nullable: bool) -> DataType { - DataType { kind, nullable } - } - #[inline] - pub fn is_nullable(&self) -> bool { - self.nullable - } - #[inline] - pub fn kind(&self) -> DataTypeKind { - self.kind.clone() - } -} +use arrow::datatypes::IntervalUnit; +use integer_encoding::FixedInt; +use strum_macros::AsRefStr; -pub trait DataTypeExt { - fn nullable(self) -> DataType; - fn not_null(self) -> DataType; -} +use crate::types::errors::TypeError; -impl DataTypeExt for DataTypeKind { - #[inline] - fn nullable(self) -> DataType { - DataType::new(self, true) - } - #[inline] - fn not_null(self) -> DataType { - DataType::new(self, false) - } -} +static ID_BUF: AtomicU32 = AtomicU32::new(0); pub(crate) struct IdGenerator {} @@ -68,11 +34,320 @@ impl IdGenerator { pub type TableId = u32; pub type ColumnId = u32; +/// Sqlrs type conversion: +/// sqlparser::ast::DataType -> LogicalType -> arrow::datatypes::DataType +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, AsRefStr)] +pub enum LogicalType { + Invalid, + SqlNull, + Boolean, + Tinyint, + UTinyint, + Smallint, + USmallint, + Integer, + UInteger, + Bigint, + UBigint, + Float, + Double, + Varchar, + Date, + Interval(IntervalUnit), +} + +impl LogicalType { + pub fn numeric() -> Vec { + vec![ + LogicalType::Tinyint, + LogicalType::UTinyint, + LogicalType::Smallint, + LogicalType::USmallint, + LogicalType::Integer, + LogicalType::UInteger, + LogicalType::Bigint, + LogicalType::UBigint, + LogicalType::Float, + LogicalType::Double, + ] + } + + pub fn is_numeric(&self) -> bool { + matches!( + self, + LogicalType::Tinyint + | LogicalType::UTinyint + | LogicalType::Smallint + | LogicalType::USmallint + | LogicalType::Integer + | LogicalType::UInteger + | LogicalType::Bigint + | LogicalType::UBigint + | LogicalType::Float + | LogicalType::Double + ) + } + + pub fn is_signed_numeric(&self) -> bool { + matches!( + self, + LogicalType::Tinyint + | LogicalType::Smallint + | LogicalType::Integer + | LogicalType::Bigint + ) + } + + pub fn is_unsigned_numeric(&self) -> bool { + matches!( + self, + LogicalType::UTinyint + | LogicalType::USmallint + | LogicalType::UInteger + | LogicalType::UBigint + ) + } + + pub fn max_logical_type( + left: &LogicalType, + right: &LogicalType, + ) -> Result { + if left == right { + return Ok(left.clone()); + } + match (left, right) { + // SqlNull type can be cast to anything + (LogicalType::SqlNull, _) => return Ok(right.clone()), + (_, LogicalType::SqlNull) => return Ok(left.clone()), + _ => {} + } + if left.is_numeric() && right.is_numeric() { + return LogicalType::combine_numeric_types(left, right); + } + Err(TypeError::InternalError(format!( + "can not compare two types: {:?} and {:?}", + left, right + ))) + } + + fn combine_numeric_types( + left: &LogicalType, + right: &LogicalType, + ) -> Result { + if left == right { + return Ok(left.clone()); + } + if left.is_signed_numeric() && right.is_unsigned_numeric() { + // this method is symmetric + // arrange it so the left type is smaller + // to limit the number of options we need to check + return LogicalType::combine_numeric_types(right, left); + } + + if LogicalType::can_implicit_cast(left, right) { + return Ok(right.clone()); + } + if LogicalType::can_implicit_cast(right, left) { + return Ok(left.clone()); + } + // we can't cast implicitly either way and types are not equal + // this happens when left is signed and right is unsigned + // e.g. INTEGER and UINTEGER + // in this case we need to upcast to make sure the types fit + match (left, right) { + (LogicalType::Bigint, _) | (_, LogicalType::UBigint) => Ok(LogicalType::Double), + (LogicalType::Integer, _) | (_, LogicalType::UInteger) => Ok(LogicalType::Bigint), + (LogicalType::Smallint, _) | (_, LogicalType::USmallint) => Ok(LogicalType::Integer), + (LogicalType::Tinyint, _) | (_, LogicalType::UTinyint) => Ok(LogicalType::Smallint), + _ => Err(TypeError::InternalError(format!( + "can not combine these numeric types {:?} and {:?}", + left, right + ))), + } + } + + pub fn can_implicit_cast(from: &LogicalType, to: &LogicalType) -> bool { + if from == to { + return true; + } + match from { + LogicalType::Invalid => false, + LogicalType::SqlNull => true, + LogicalType::Boolean => false, + LogicalType::Tinyint => matches!( + to, + LogicalType::Smallint + | LogicalType::Integer + | LogicalType::Bigint + | LogicalType::Float + | LogicalType::Double + ), + LogicalType::UTinyint => matches!( + to, + LogicalType::USmallint + | LogicalType::UInteger + | LogicalType::UBigint + | LogicalType::Smallint + | LogicalType::Integer + | LogicalType::Bigint + | LogicalType::Float + | LogicalType::Double + ), + LogicalType::Smallint => matches!( + to, + LogicalType::Integer + | LogicalType::Bigint + | LogicalType::Float + | LogicalType::Double + ), + LogicalType::USmallint => matches!( + to, + LogicalType::UInteger + | LogicalType::UBigint + | LogicalType::Integer + | LogicalType::Bigint + | LogicalType::Float + | LogicalType::Double + ), + LogicalType::Integer => matches!( + to, + LogicalType::Bigint | LogicalType::Float | LogicalType::Double + ), + LogicalType::UInteger => matches!( + to, + LogicalType::UBigint + | LogicalType::Bigint + | LogicalType::Float + | LogicalType::Double + ), + LogicalType::Bigint => matches!(to, LogicalType::Float | LogicalType::Double), + LogicalType::UBigint => matches!(to, LogicalType::Float | LogicalType::Double), + LogicalType::Float => matches!(to, LogicalType::Double), + LogicalType::Double => false, + LogicalType::Varchar => false, + LogicalType::Date => false, + LogicalType::Interval(_) => false, + } + } +} + +/// sqlparser datatype to logical type +impl TryFrom for LogicalType { + type Error = TypeError; + + fn try_from(value: sqlparser::ast::DataType) -> Result { + match value { + sqlparser::ast::DataType::Char(_) + | sqlparser::ast::DataType::Varchar(_) + | sqlparser::ast::DataType::Nvarchar(_) + | sqlparser::ast::DataType::Text + | sqlparser::ast::DataType::String => Ok(LogicalType::Varchar), + sqlparser::ast::DataType::Float(_) => Ok(LogicalType::Float), + sqlparser::ast::DataType::Double => Ok(LogicalType::Double), + sqlparser::ast::DataType::TinyInt(_) => Ok(LogicalType::Tinyint), + sqlparser::ast::DataType::UnsignedTinyInt(_) => Ok(LogicalType::UTinyint), + sqlparser::ast::DataType::SmallInt(_) => Ok(LogicalType::Smallint), + sqlparser::ast::DataType::UnsignedSmallInt(_) => Ok(LogicalType::USmallint), + sqlparser::ast::DataType::Int(_) | sqlparser::ast::DataType::Integer(_) => { + Ok(LogicalType::Integer) + } + sqlparser::ast::DataType::UnsignedInt(_) + | sqlparser::ast::DataType::UnsignedInteger(_) => Ok(LogicalType::UInteger), + sqlparser::ast::DataType::BigInt(_) => Ok(LogicalType::Bigint), + sqlparser::ast::DataType::UnsignedBigInt(_) => Ok(LogicalType::UBigint), + sqlparser::ast::DataType::Boolean => Ok(LogicalType::Boolean), + sqlparser::ast::DataType::Date => Ok(LogicalType::Date), + // use day time interval for default interval value + sqlparser::ast::DataType::Interval => Ok(LogicalType::Interval(IntervalUnit::DayTime)), + other => Err(TypeError::NotImplementedSqlparserDataType( + other.to_string(), + )), + } + } +} + +impl From for arrow::datatypes::DataType { + fn from(value: LogicalType) -> Self { + use arrow::datatypes::DataType; + match value { + LogicalType::Invalid => panic!("invalid logical type"), + LogicalType::SqlNull => DataType::Null, + LogicalType::Boolean => DataType::Boolean, + LogicalType::Tinyint => DataType::Int8, + LogicalType::UTinyint => DataType::UInt8, + LogicalType::Smallint => DataType::Int16, + LogicalType::USmallint => DataType::UInt16, + LogicalType::Integer => DataType::Int32, + LogicalType::UInteger => DataType::UInt32, + LogicalType::Bigint => DataType::Int64, + LogicalType::UBigint => DataType::UInt64, + LogicalType::Float => DataType::Float32, + LogicalType::Double => DataType::Float64, + LogicalType::Varchar => DataType::Utf8, + LogicalType::Date => DataType::Date32, + LogicalType::Interval(u) => DataType::Interval(u), + } + } +} + +impl TryFrom<&arrow::datatypes::DataType> for LogicalType { + type Error = TypeError; + + fn try_from(value: &arrow::datatypes::DataType) -> Result { + use arrow::datatypes::DataType; + Ok(match value { + DataType::Null => LogicalType::SqlNull, + DataType::Boolean => LogicalType::Boolean, + DataType::Int8 => LogicalType::Tinyint, + DataType::Int16 => LogicalType::Smallint, + DataType::Int32 => LogicalType::Integer, + DataType::Int64 => LogicalType::Bigint, + DataType::UInt8 => LogicalType::UTinyint, + DataType::UInt16 => LogicalType::USmallint, + DataType::UInt32 => LogicalType::UInteger, + DataType::UInt64 => LogicalType::UBigint, + DataType::Float16 => LogicalType::Float, + DataType::Float32 => LogicalType::Float, + DataType::Float64 => LogicalType::Double, + DataType::Utf8 => LogicalType::Varchar, + DataType::LargeUtf8 => LogicalType::Varchar, + DataType::Date32 => LogicalType::Date, + DataType::Interval(u) => LogicalType::Interval(u.clone()), + DataType::Timestamp(_, _) + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Binary + | DataType::FixedSizeBinary(_) + | DataType::LargeBinary + | DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) + | DataType::Struct(_) + | DataType::Union(_, _, _) + | DataType::Dictionary(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Map(_, _) => { + return Err(TypeError::NotImplementedArrowDataType(value.to_string())) + } + }) + } +} + +impl std::fmt::Display for LogicalType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_ref()) + } +} + #[cfg(test)] mod test { - use crate::types::{IdGenerator, ID_BUF}; use std::sync::atomic::Ordering::Release; + use crate::types::{IdGenerator, ID_BUF}; + /// Tips: 由于IdGenerator为static全局性质生成的id,因此需要单独测试避免其他测试方法干扰 #[test] #[ignore] diff --git a/src/types/value.rs b/src/types/value.rs index a54be347..e819edfd 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -1,21 +1,605 @@ -use super::DataType; +use std::cmp::Ordering; +use std::fmt; +use std::hash::Hash; +use std::iter::repeat; +use std::sync::Arc; -#[derive(Debug, PartialEq, Clone)] +use arrow::array::{ + new_null_array, ArrayBuilder, ArrayRef, BooleanArray, BooleanBuilder, Date32Array, + Date32Builder, Float32Array, Float32Builder, Float64Array, Float64Builder, Int16Array, + Int16Builder, Int32Array, Int32Builder, Int64Array, Int64Builder, Int8Array, Int8Builder, + IntervalDayTimeArray, IntervalDayTimeBuilder, IntervalMonthDayNanoBuilder, + IntervalYearMonthArray, IntervalYearMonthBuilder, StringArray, StringBuilder, UInt16Array, + UInt16Builder, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder, UInt8Array, + UInt8Builder, +}; +use arrow::datatypes::{DataType, IntervalUnit}; +use ordered_float::OrderedFloat; + +use super::{LogicalType, TypeError}; + +#[derive(Clone)] pub enum DataValue { Null, - Bool(bool), - Int32(i32), - Int64(i64), - Float64(f64), - String(String), - // Blob(Blob), - // Decimal(Decimal), - // Date(Date), - // Interval(Interval), + Boolean(Option), + Float32(Option), + Float64(Option), + Int8(Option), + Int16(Option), + Int32(Option), + Int64(Option), + UInt8(Option), + UInt16(Option), + UInt32(Option), + UInt64(Option), + Utf8(Option), + /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 + Date32(Option), + /// Number of elapsed whole months + IntervalYearMonth(Option), + /// Number of elapsed days and milliseconds (no leap seconds) + /// stored as 2 contiguous 32-bit signed integers + IntervalDayTime(Option), +} + +impl PartialEq for DataValue { + fn eq(&self, other: &Self) -> bool { + use DataValue::*; + match (self, other) { + (Boolean(v1), Boolean(v2)) => v1.eq(v2), + (Boolean(_), _) => false, + (Float32(v1), Float32(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.eq(&v2) + } + (Float32(_), _) => false, + (Float64(v1), Float64(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.eq(&v2) + } + (Float64(_), _) => false, + (Int8(v1), Int8(v2)) => v1.eq(v2), + (Int8(_), _) => false, + (Int16(v1), Int16(v2)) => v1.eq(v2), + (Int16(_), _) => false, + (Int32(v1), Int32(v2)) => v1.eq(v2), + (Int32(_), _) => false, + (Int64(v1), Int64(v2)) => v1.eq(v2), + (Int64(_), _) => false, + (UInt8(v1), UInt8(v2)) => v1.eq(v2), + (UInt8(_), _) => false, + (UInt16(v1), UInt16(v2)) => v1.eq(v2), + (UInt16(_), _) => false, + (UInt32(v1), UInt32(v2)) => v1.eq(v2), + (UInt32(_), _) => false, + (UInt64(v1), UInt64(v2)) => v1.eq(v2), + (UInt64(_), _) => false, + (Utf8(v1), Utf8(v2)) => v1.eq(v2), + (Utf8(_), _) => false, + (Null, Null) => true, + (Null, _) => false, + (Date32(v1), Date32(v2)) => v1.eq(v2), + (Date32(_), _) => false, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2), + (IntervalYearMonth(_), _) => false, + (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2), + (IntervalDayTime(_), _) => false, + } + } +} + +impl PartialOrd for DataValue { + fn partial_cmp(&self, other: &Self) -> Option { + use DataValue::*; + match (self, other) { + (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2), + (Boolean(_), _) => None, + (Float32(v1), Float32(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.partial_cmp(&v2) + } + (Float32(_), _) => None, + (Float64(v1), Float64(v2)) => { + let v1 = v1.map(OrderedFloat); + let v2 = v2.map(OrderedFloat); + v1.partial_cmp(&v2) + } + (Float64(_), _) => None, + (Int8(v1), Int8(v2)) => v1.partial_cmp(v2), + (Int8(_), _) => None, + (Int16(v1), Int16(v2)) => v1.partial_cmp(v2), + (Int16(_), _) => None, + (Int32(v1), Int32(v2)) => v1.partial_cmp(v2), + (Int32(_), _) => None, + (Int64(v1), Int64(v2)) => v1.partial_cmp(v2), + (Int64(_), _) => None, + (UInt8(v1), UInt8(v2)) => v1.partial_cmp(v2), + (UInt8(_), _) => None, + (UInt16(v1), UInt16(v2)) => v1.partial_cmp(v2), + (UInt16(_), _) => None, + (UInt32(v1), UInt32(v2)) => v1.partial_cmp(v2), + (UInt32(_), _) => None, + (UInt64(v1), UInt64(v2)) => v1.partial_cmp(v2), + (UInt64(_), _) => None, + (Utf8(v1), Utf8(v2)) => v1.partial_cmp(v2), + (Utf8(_), _) => None, + (Null, Null) => Some(Ordering::Equal), + (Null, _) => None, + (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), + (Date32(_), _) => None, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), + (IntervalYearMonth(_), _) => None, + (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.partial_cmp(v2), + (IntervalDayTime(_), _) => None, + } + } +} + +impl Eq for DataValue {} + +impl Hash for DataValue { + fn hash(&self, state: &mut H) { + use DataValue::*; + match self { + Boolean(v) => v.hash(state), + Float32(v) => { + let v = v.map(OrderedFloat); + v.hash(state) + } + Float64(v) => { + let v = v.map(OrderedFloat); + v.hash(state) + } + Int8(v) => v.hash(state), + Int16(v) => v.hash(state), + Int32(v) => v.hash(state), + Int64(v) => v.hash(state), + UInt8(v) => v.hash(state), + UInt16(v) => v.hash(state), + UInt32(v) => v.hash(state), + UInt64(v) => v.hash(state), + Utf8(v) => v.hash(state), + Null => 1.hash(state), + Date32(v) => v.hash(state), + IntervalYearMonth(v) => v.hash(state), + IntervalDayTime(v) => v.hash(state), + } + } +} + +macro_rules! typed_cast { + ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{ + let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap(); + DataValue::$SCALAR(match array.is_null($index) { + true => None, + false => Some(array.value($index).into()), + }) + }}; +} + +macro_rules! build_array_from_option { + ($DATA_TYPE:ident, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{ + match $EXPR { + Some(value) => Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)), + None => new_null_array(&DataType::$DATA_TYPE, $SIZE), + } + }}; + ($DATA_TYPE:ident, $ENUM:expr, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{ + match $EXPR { + Some(value) => Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)), + None => new_null_array(&DataType::$DATA_TYPE($ENUM), $SIZE), + } + }}; + ($DATA_TYPE:ident, $ENUM:expr, $ENUM2:expr, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{ + match $EXPR { + Some(value) => { + let array: ArrayRef = Arc::new($ARRAY_TYPE::from_value(*value, $SIZE)); + // Need to call cast to cast to final data type with timezone/extra param + cast(&array, &DataType::$DATA_TYPE($ENUM, $ENUM2)).expect("cannot do temporal cast") + } + None => new_null_array(&DataType::$DATA_TYPE($ENUM, $ENUM2), $SIZE), + } + }}; } + impl DataValue { - /// Get the type of value. `None` means NULL. - pub fn data_type(&self) -> Option { - todo!() + pub fn new_none_value(data_type: &DataType) -> Result { + match data_type { + DataType::Null => Ok(DataValue::Null), + DataType::Boolean => Ok(DataValue::Boolean(None)), + DataType::Float32 => Ok(DataValue::Float32(None)), + DataType::Float64 => Ok(DataValue::Float64(None)), + DataType::Int8 => Ok(DataValue::Int8(None)), + DataType::Int16 => Ok(DataValue::Int16(None)), + DataType::Int32 => Ok(DataValue::Int32(None)), + DataType::Int64 => Ok(DataValue::Int64(None)), + DataType::UInt8 => Ok(DataValue::UInt8(None)), + DataType::UInt16 => Ok(DataValue::UInt16(None)), + DataType::UInt32 => Ok(DataValue::UInt32(None)), + DataType::UInt64 => Ok(DataValue::UInt64(None)), + DataType::Utf8 => Ok(DataValue::Utf8(None)), + other => Err(TypeError::NotImplementedArrowDataType(other.to_string())), + } + } + + /// Converts a value in `array` at `index` into a DataValue + pub fn try_from_array(array: &ArrayRef, index: usize) -> Result { + if !array.is_valid(index) { + return Self::new_none_value(array.data_type()); + } + + use arrow::array::*; + + Ok(match array.data_type() { + DataType::Null => DataValue::Null, + DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), + DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), + DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), + DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64), + DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32), + DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16), + DataType::UInt8 => typed_cast!(array, index, UInt8Array, UInt8), + DataType::Int64 => typed_cast!(array, index, Int64Array, Int64), + DataType::Int32 => typed_cast!(array, index, Int32Array, Int32), + DataType::Int16 => typed_cast!(array, index, Int16Array, Int16), + DataType::Int8 => typed_cast!(array, index, Int8Array, Int8), + DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), + other => { + return Err(TypeError::NotImplementedArrowDataType(other.to_string())); + } + }) + } + + pub fn get_logical_type(&self) -> LogicalType { + match self { + DataValue::Null => LogicalType::SqlNull, + DataValue::Boolean(_) => LogicalType::Boolean, + DataValue::Float32(_) => LogicalType::Float, + DataValue::Float64(_) => LogicalType::Double, + DataValue::Int8(_) => LogicalType::Tinyint, + DataValue::Int16(_) => LogicalType::Smallint, + DataValue::Int32(_) => LogicalType::Integer, + DataValue::Int64(_) => LogicalType::Bigint, + DataValue::UInt8(_) => LogicalType::UTinyint, + DataValue::UInt16(_) => LogicalType::USmallint, + DataValue::UInt32(_) => LogicalType::UInteger, + DataValue::UInt64(_) => LogicalType::UBigint, + DataValue::Utf8(_) => LogicalType::Varchar, + DataValue::Date32(_) => LogicalType::Date, + DataValue::IntervalYearMonth(_) => LogicalType::Interval(IntervalUnit::YearMonth), + DataValue::IntervalDayTime(_) => LogicalType::Interval(IntervalUnit::DayTime), + } + } + + /// Converts a scalar value into an 1-row array. + pub fn to_array(&self) -> ArrayRef { + self.to_array_of_size(1) + } + + /// Converts a scalar value into an array of `size` rows. + pub fn to_array_of_size(&self, size: usize) -> ArrayRef { + match self { + DataValue::Boolean(e) => Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef, + DataValue::Float64(e) => { + build_array_from_option!(Float64, Float64Array, e, size) + } + DataValue::Float32(e) => { + build_array_from_option!(Float32, Float32Array, e, size) + } + DataValue::Int8(e) => build_array_from_option!(Int8, Int8Array, e, size), + DataValue::Int16(e) => build_array_from_option!(Int16, Int16Array, e, size), + DataValue::Int32(e) => build_array_from_option!(Int32, Int32Array, e, size), + DataValue::Int64(e) => build_array_from_option!(Int64, Int64Array, e, size), + DataValue::UInt8(e) => build_array_from_option!(UInt8, UInt8Array, e, size), + DataValue::UInt16(e) => { + build_array_from_option!(UInt16, UInt16Array, e, size) + } + DataValue::UInt32(e) => { + build_array_from_option!(UInt32, UInt32Array, e, size) + } + DataValue::UInt64(e) => { + build_array_from_option!(UInt64, UInt64Array, e, size) + } + + DataValue::Utf8(e) => match e { + Some(value) => Arc::new(StringArray::from_iter_values(repeat(value).take(size))), + None => new_null_array(&DataType::Utf8, size), + }, + DataValue::Null => new_null_array(&DataType::Null, size), + DataValue::Date32(e) => { + build_array_from_option!(Date32, Date32Array, e, size) + } + DataValue::IntervalDayTime(e) => build_array_from_option!( + Interval, + IntervalUnit::DayTime, + IntervalDayTimeArray, + e, + size + ), + DataValue::IntervalYearMonth(e) => build_array_from_option!( + Interval, + IntervalUnit::YearMonth, + IntervalYearMonthArray, + e, + size + ), + } + } + + pub fn new_builder(data_type: &LogicalType) -> Result, TypeError> { + match data_type { + LogicalType::Invalid | LogicalType::SqlNull => Err(TypeError::InternalError(format!( + "Unsupported type {:?} for builder", + data_type + ))), + LogicalType::Boolean => Ok(Box::new(BooleanBuilder::new())), + LogicalType::Tinyint => Ok(Box::new(Int8Builder::new())), + LogicalType::UTinyint => Ok(Box::new(UInt8Builder::new())), + LogicalType::Smallint => Ok(Box::new(Int16Builder::new())), + LogicalType::USmallint => Ok(Box::new(UInt16Builder::new())), + LogicalType::Integer => Ok(Box::new(Int32Builder::new())), + LogicalType::UInteger => Ok(Box::new(UInt32Builder::new())), + LogicalType::Bigint => Ok(Box::new(Int64Builder::new())), + LogicalType::UBigint => Ok(Box::new(UInt64Builder::new())), + LogicalType::Float => Ok(Box::new(Float32Builder::new())), + LogicalType::Double => Ok(Box::new(Float64Builder::new())), + LogicalType::Varchar => Ok(Box::new(StringBuilder::new())), + LogicalType::Date => Ok(Box::new(Date32Builder::new())), + LogicalType::Interval(IntervalUnit::DayTime) => { + Ok(Box::new(IntervalDayTimeBuilder::new())) + } + LogicalType::Interval(IntervalUnit::YearMonth) => { + Ok(Box::new(IntervalYearMonthBuilder::new())) + } + LogicalType::Interval(IntervalUnit::MonthDayNano) => { + Ok(Box::new(IntervalMonthDayNanoBuilder::new())) + } + } + } + + pub fn append_for_builder( + value: &DataValue, + builder: &mut Box, + ) -> Result<(), TypeError> { + match value { + DataValue::Null => { + return Err(TypeError::InternalError( + "Unsupported type: Null for builder".to_string(), + )) + } + DataValue::Boolean(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::Utf8(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(v.as_ref()), + DataValue::Int8(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::Int16(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::Int32(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::Int64(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::UInt8(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::UInt16(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::UInt32(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::UInt64(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::Float32(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::Float64(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::Date32(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::IntervalYearMonth(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + DataValue::IntervalDayTime(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + } + Ok(()) + } + + pub fn get_logic_type(&self) -> LogicalType { + match self { + DataValue::Boolean(_) => LogicalType::Boolean, + DataValue::UInt8(_) => LogicalType::UTinyint, + DataValue::UInt16(_) => LogicalType::USmallint, + DataValue::UInt32(_) => LogicalType::UInteger, + DataValue::UInt64(_) => LogicalType::UBigint, + DataValue::Int8(_) => LogicalType::Tinyint, + DataValue::Int16(_) => LogicalType::Smallint, + DataValue::Int32(_) => LogicalType::Integer, + DataValue::Int64(_) => LogicalType::Bigint, + DataValue::Float32(_) => LogicalType::Float, + DataValue::Float64(_) => LogicalType::Double, + DataValue::Utf8(_) => LogicalType::Varchar, + DataValue::Null => LogicalType::Invalid, + DataValue::Date32(_) => LogicalType::Date, + DataValue::IntervalYearMonth(_) => LogicalType::Interval(IntervalUnit::YearMonth), + DataValue::IntervalDayTime(_) => LogicalType::Interval(IntervalUnit::DayTime), + } + } + + pub fn get_datatype(&self) -> DataType { + match self { + DataValue::Boolean(_) => DataType::Boolean, + DataValue::UInt8(_) => DataType::UInt8, + DataValue::UInt16(_) => DataType::UInt16, + DataValue::UInt32(_) => DataType::UInt32, + DataValue::UInt64(_) => DataType::UInt64, + DataValue::Int8(_) => DataType::Int8, + DataValue::Int16(_) => DataType::Int16, + DataValue::Int32(_) => DataType::Int32, + DataValue::Int64(_) => DataType::Int64, + DataValue::Float32(_) => DataType::Float32, + DataValue::Float64(_) => DataType::Float64, + DataValue::Utf8(_) => DataType::Utf8, + DataValue::Null => DataType::Null, + DataValue::Date32(_) => DataType::Date32, + DataValue::IntervalYearMonth(_) => DataType::Interval(IntervalUnit::YearMonth), + DataValue::IntervalDayTime(_) => DataType::Interval(IntervalUnit::DayTime), + } + } +} + +macro_rules! impl_scalar { + ($ty:ty, $scalar:tt) => { + impl From<$ty> for DataValue { + fn from(value: $ty) -> Self { + DataValue::$scalar(Some(value)) + } + } + + impl From> for DataValue { + fn from(value: Option<$ty>) -> Self { + DataValue::$scalar(value) + } + } + }; +} + +impl_scalar!(f64, Float64); +impl_scalar!(f32, Float32); +impl_scalar!(i8, Int8); +impl_scalar!(i16, Int16); +impl_scalar!(i32, Int32); +impl_scalar!(i64, Int64); +impl_scalar!(bool, Boolean); +impl_scalar!(u8, UInt8); +impl_scalar!(u16, UInt16); +impl_scalar!(u32, UInt32); +impl_scalar!(u64, UInt64); +impl_scalar!(String, Utf8); + +impl From<&sqlparser::ast::Value> for DataValue { + fn from(v: &sqlparser::ast::Value) -> Self { + match v { + sqlparser::ast::Value::Number(n, _) => { + // use i32 to handle most cases + if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else if let Ok(v) = n.parse::() { + v.into() + } else { + panic!("unsupported number {:?}", n) + } + } + sqlparser::ast::Value::SingleQuotedString(s) => s.clone().into(), + sqlparser::ast::Value::DoubleQuotedString(s) => s.clone().into(), + sqlparser::ast::Value::Boolean(b) => (*b).into(), + sqlparser::ast::Value::Null => Self::Null, + _ => todo!("unsupported parsed scalar value {:?}", v), + } + } +} + +macro_rules! format_option { + ($F:expr, $EXPR:expr) => {{ + match $EXPR { + Some(e) => write!($F, "{}", e), + None => write!($F, "NULL"), + } + }}; +} + +impl fmt::Display for DataValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + DataValue::Boolean(e) => format_option!(f, e)?, + DataValue::Float32(e) => format_option!(f, e)?, + DataValue::Float64(e) => format_option!(f, e)?, + DataValue::Int8(e) => format_option!(f, e)?, + DataValue::Int16(e) => format_option!(f, e)?, + DataValue::Int32(e) => format_option!(f, e)?, + DataValue::Int64(e) => format_option!(f, e)?, + DataValue::UInt8(e) => format_option!(f, e)?, + DataValue::UInt16(e) => format_option!(f, e)?, + DataValue::UInt32(e) => format_option!(f, e)?, + DataValue::UInt64(e) => format_option!(f, e)?, + DataValue::Utf8(e) => format_option!(f, e)?, + DataValue::Null => write!(f, "NULL")?, + DataValue::Date32(e) => format_option!(f, e)?, + DataValue::IntervalDayTime(e) => format_option!(f, e)?, + DataValue::IntervalYearMonth(e) => format_option!(f, e)?, + }; + Ok(()) + } +} + +impl fmt::Debug for DataValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + DataValue::Boolean(_) => write!(f, "Boolean({})", self), + DataValue::Float32(_) => write!(f, "Float32({})", self), + DataValue::Float64(_) => write!(f, "Float64({})", self), + DataValue::Int8(_) => write!(f, "Int8({})", self), + DataValue::Int16(_) => write!(f, "Int16({})", self), + DataValue::Int32(_) => write!(f, "Int32({})", self), + DataValue::Int64(_) => write!(f, "Int64({})", self), + DataValue::UInt8(_) => write!(f, "UInt8({})", self), + DataValue::UInt16(_) => write!(f, "UInt16({})", self), + DataValue::UInt32(_) => write!(f, "UInt32({})", self), + DataValue::UInt64(_) => write!(f, "UInt64({})", self), + DataValue::Utf8(None) => write!(f, "Utf8({})", self), + DataValue::Utf8(Some(_)) => write!(f, "Utf8(\"{}\")", self), + DataValue::Null => write!(f, "NULL"), + DataValue::Date32(_) => write!(f, "Date32({})", self), + DataValue::IntervalYearMonth(_) => write!(f, "IntervalYearMonth({})", self), + DataValue::IntervalDayTime(_) => write!(f, "IntervalDayTime({})", self), + } } }