From 252ae5ea37e1564eb00e9c0933b0e0a2f33a3673 Mon Sep 17 00:00:00 2001 From: Kould Date: Thu, 18 Sep 2025 14:33:47 +0800 Subject: [PATCH] refactor: Use `TupleValueSerializable` to simplify DataValue serialization --- src/binder/select.rs | 5 +- src/catalog/table.rs | 6 - src/execution/ddl/add_column.rs | 10 +- src/execution/ddl/drop_column.rs | 4 +- src/execution/dml/copy_from_file.rs | 12 +- src/execution/dml/insert.rs | 7 +- src/execution/dml/update.rs | 9 +- src/execution/dql/index_scan.rs | 2 +- src/execution/{marco.rs => execute_macro.rs} | 0 src/execution/mod.rs | 2 +- src/storage/mod.rs | 175 +++--- src/storage/rocksdb.rs | 90 ++- src/storage/table_codec.rs | 36 +- src/types/mod.rs | 1 + src/types/serialize.rs | 580 +++++++++++++++++++ src/types/tuple.rs | 155 +++-- src/types/value.rs | 287 +-------- 17 files changed, 888 insertions(+), 493 deletions(-) rename src/execution/{marco.rs => execute_macro.rs} (100%) create mode 100644 src/types/serialize.rs diff --git a/src/binder/select.rs b/src/binder/select.rs index bdd6a0a3..4a5e5fdf 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -159,10 +159,7 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<' Ok(plan) } - fn bind_temp_values( - &mut self, - expr_rows: &Vec>, - ) -> Result { + fn bind_temp_values(&mut self, expr_rows: &[Vec]) -> Result { let values_len = expr_rows[0].len(); let mut inferred_types: Vec> = vec![None; values_len]; diff --git a/src/catalog/table.rs b/src/catalog/table.rs index 94ccc45a..4ee9d6a1 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -89,12 +89,6 @@ impl TableCatalog { &self.primary_key_indices } - pub(crate) fn types(&self) -> Vec { - self.columns() - .map(|column| column.datatype().clone()) - .collect_vec() - } - /// Add a column to the table catalog. 
pub(crate) fn add_column( &mut self, diff --git a/src/execution/ddl/add_column.rs b/src/execution/ddl/add_column.rs index f91ffe3a..9e8c5de9 100644 --- a/src/execution/ddl/add_column.rs +++ b/src/execution/ddl/add_column.rs @@ -9,6 +9,7 @@ use crate::types::value::DataValue; use crate::{ planner::operator::alter_table::add_column::AddColumnOperator, storage::Transaction, throw, }; +use itertools::Itertools; use std::ops::Coroutine; use std::ops::CoroutineState; use std::pin::Pin; @@ -69,9 +70,14 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for AddColumn { } drop(coroutine); + let serializers = types.iter().map(|ty| ty.serializable()).collect_vec(); for tuple in tuples { - throw!(unsafe { &mut (*transaction) } - .append_tuple(table_name, tuple, &types, true)); + throw!(unsafe { &mut (*transaction) }.append_tuple( + table_name, + tuple, + &serializers, + true + )); } let col_id = throw!(unsafe { &mut (*transaction) }.add_column( cache.0, diff --git a/src/execution/ddl/drop_column.rs b/src/execution/ddl/drop_column.rs index 2bb818bb..0c5d5b56 100644 --- a/src/execution/ddl/drop_column.rs +++ b/src/execution/ddl/drop_column.rs @@ -6,6 +6,7 @@ use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; use crate::throw; use crate::types::tuple::Tuple; use crate::types::tuple_builder::TupleBuilder; +use itertools::Itertools; use std::ops::Coroutine; use std::ops::CoroutineState; use std::pin::Pin; @@ -66,11 +67,12 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for DropColumn { tuples.push(tuple); } drop(coroutine); + let serializers = types.iter().map(|ty| ty.serializable()).collect_vec(); for tuple in tuples { throw!(unsafe { &mut (*transaction) }.append_tuple( &table_name, tuple, - &types, + &serializers, true )); } diff --git a/src/execution/dml/copy_from_file.rs b/src/execution/dml/copy_from_file.rs index 8ab3bf66..8833581e 100644 --- a/src/execution/dml/copy_from_file.rs +++ b/src/execution/dml/copy_from_file.rs @@ -5,8 +5,9 @@ use crate::execution::{Executor, WriteExecutor}; use crate::planner::operator::copy_from_file::CopyFromFileOperator; use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; use crate::throw; -use crate::types::tuple::{types, Tuple}; +use crate::types::tuple::Tuple; use crate::types::tuple_builder::TupleBuilder; +use itertools::Itertools; use std::fs::File; use std::io::BufReader; use std::sync::mpsc; @@ -33,7 +34,12 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for CopyFromFile { Box::new( #[coroutine] move || { - let types = types(&self.op.schema_ref); + let serializers = self + .op + .schema_ref + .iter() + .map(|column| column.datatype().serializable()) + .collect_vec(); let (tx, rx) = mpsc::channel(); let (tx1, rx1) = mpsc::channel(); // # Cancellation @@ -50,7 +56,7 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for CopyFromFile { throw!(unsafe { &mut (*transaction) }.append_tuple( table.name(), chunk, - &types, + &serializers, false )); size += 1; diff --git a/src/execution/dml/insert.rs b/src/execution/dml/insert.rs index 15c79bd5..ede0e838 100644 --- a/src/execution/dml/insert.rs +++ b/src/execution/dml/insert.rs @@ -99,7 +99,10 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Insert { index_metas.push((index_meta, exprs)); } - let types = table_catalog.types(); + let serializers = table_catalog + .columns() + .map(|column| column.datatype().serializable()) + .collect_vec(); let pk_indices = table_catalog.primary_keys_indices(); let mut coroutine = build_read(input, cache, 
transaction); @@ -147,7 +150,7 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Insert { throw!(unsafe { &mut (*transaction) }.append_tuple( &table_name, tuple, - &types, + &serializers, is_overwrite )); inserted_count += 1; diff --git a/src/execution/dml/update.rs b/src/execution/dml/update.rs index ee60c351..b63b795a 100644 --- a/src/execution/dml/update.rs +++ b/src/execution/dml/update.rs @@ -8,10 +8,10 @@ use crate::planner::LogicalPlan; use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache}; use crate::throw; use crate::types::index::Index; -use crate::types::tuple::types; use crate::types::tuple::Tuple; use crate::types::tuple_builder::TupleBuilder; use crate::types::value::DataValue; +use itertools::Itertools; use std::collections::HashMap; use std::ops::Coroutine; use std::ops::CoroutineState; @@ -62,7 +62,10 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Update { } let input_schema = input.output_schema().clone(); - let types = types(&input_schema); + let serializers = input_schema + .iter() + .map(|column| column.datatype().serializable()) + .collect_vec(); let mut updated_count = 0; if let Some(table_catalog) = @@ -133,7 +136,7 @@ impl<'a, T: Transaction + 'a> WriteExecutor<'a, T> for Update { throw!(unsafe { &mut (*transaction) }.append_tuple( &table_name, tuple, - &types, + &serializers, is_overwrite )); updated_count += 1; diff --git a/src/execution/dql/index_scan.rs b/src/execution/dql/index_scan.rs index 8844db53..607302ef 100644 --- a/src/execution/dql/index_scan.rs +++ b/src/execution/dql/index_scan.rs @@ -50,7 +50,7 @@ impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for IndexScan { columns, self.index_by, self.ranges, - with_pk, + with_pk )); while let Some(tuple) = throw!(iter.next_tuple()) { diff --git a/src/execution/marco.rs b/src/execution/execute_macro.rs similarity index 100% rename from src/execution/marco.rs rename to src/execution/execute_macro.rs diff --git a/src/execution/mod.rs b/src/execution/mod.rs index 202a878d..0685d9d7 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -1,7 +1,7 @@ pub(crate) mod ddl; pub(crate) mod dml; pub(crate) mod dql; -pub(crate) mod marco; +pub(crate) mod execute_macro; use self::ddl::add_column::AddColumn; use self::dql::join::nested_loop_join::NestedLoopJoin; diff --git a/src/storage/mod.rs b/src/storage/mod.rs index f40f43c5..bee66192 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -10,6 +10,7 @@ use crate::optimizer::core::statistics_meta::{StatisticMetaLoader, StatisticsMet use crate::serdes::ReferenceTables; use crate::storage::table_codec::{BumpBytes, Bytes, TableCodec}; use crate::types::index::{Index, IndexId, IndexMetaRef, IndexType}; +use crate::types::serialize::TupleValueSerializableImpl; use crate::types::tuple::{Tuple, TupleId}; use crate::types::value::DataValue; use crate::types::{ColumnId, LogicalType}; @@ -61,19 +62,13 @@ pub trait Transaction: Sized { let table = self .table(table_cache, table_name.clone())? 
.ok_or(DatabaseError::TableNotFound)?; - let table_types = table.types(); if columns.is_empty() || with_pk { for (i, column) in table.primary_keys() { columns.insert(*i, column.clone()); } } - let mut tuple_columns = Vec::with_capacity(columns.len()); - let mut projections = Vec::with_capacity(columns.len()); - for (projection, column) in columns { - tuple_columns.push(column); - projections.push(projection); - } - let remap_pk_indices = remap_pk_indices(&projections, table.primary_keys_indices()); + let (deserializers, remap_pk_indices) = + Self::create_deserializers(&columns, table, with_pk); let (min, max) = unsafe { &*self.table_codec() }.tuple_bound(&table_name); let iter = self.range(Bound::Included(min), Bound::Included(max))?; @@ -81,11 +76,10 @@ pub trait Transaction: Sized { Ok(TupleIter { offset: bounds.0.unwrap_or(0), limit: bounds.1, - table_types, - tuple_columns: Arc::new(tuple_columns), remap_pk_indices, - projections, - with_pk, + deserializers, + values_len: columns.len(), + total_len: table.columns_len(), iter, }) } @@ -106,7 +100,6 @@ pub trait Transaction: Sized { let table = self .table(table_cache, table_name.clone())? .ok_or(DatabaseError::TableNotFound)?; - let table_types = table.types(); let table_name = table.name.as_str(); let offset = offset_option.unwrap_or(0); @@ -115,13 +108,8 @@ pub trait Transaction: Sized { columns.insert(*i, column.clone()); } } - let mut tuple_columns = Vec::with_capacity(columns.len()); - let mut projections = Vec::with_capacity(columns.len()); - for (projection, column) in columns { - tuple_columns.push(column); - projections.push(projection); - } - let remap_pk_indices = remap_pk_indices(&projections, table.primary_keys_indices()); + let (deserializers, remap_pk_indices) = + Self::create_deserializers(&columns, table, with_pk); let inner = IndexImplEnum::instance(index_meta.ty); Ok(IndexIter { @@ -129,12 +117,11 @@ pub trait Transaction: Sized { limit: limit_option, remap_pk_indices, params: IndexImplParams { - tuple_schema_ref: Arc::new(tuple_columns), - projections, index_meta, table_name, - table_types, - with_pk, + deserializers, + values_len: columns.len(), + total_len: table.columns_len(), tx: self, }, inner, @@ -143,6 +130,42 @@ pub trait Transaction: Sized { }) } + fn create_deserializers( + columns: &BTreeMap, + table: &TableCatalog, + with_pk: bool, + ) -> (Vec, Option>) { + let primary_keys_indices = table.primary_keys_indices(); + + let mut deserializers = Vec::with_capacity(columns.len()); + let mut projections = Vec::with_capacity(columns.len()); + let mut last_projection = None; + for (projection, column) in columns.iter() { + let (start, end) = last_projection + .map(|last_projection| { + let start = last_projection + 1; + let len = projection - start; + (start, start + len) + }) + .unwrap_or((0, *projection)); + for skip_column in table.schema_ref()[start..end].iter() { + deserializers.push(skip_column.datatype().skip_serializable()); + } + if with_pk { + projections.push(*projection); + } + deserializers.push(column.datatype().serializable()); + last_projection = Some(*projection); + } + let remap_pk_indices = with_pk.then(|| { + primary_keys_indices + .iter() + .filter_map(|pk| projections.binary_search(pk).ok()) + .collect_vec() + }); + (deserializers, remap_pk_indices) + } + fn add_index_meta( &mut self, table_cache: &TableCache, @@ -212,11 +235,11 @@ pub trait Transaction: Sized { &mut self, table_name: &str, mut tuple: Tuple, - types: &[LogicalType], + serializers: &[TupleValueSerializableImpl], is_overwrite: 
bool, ) -> Result<(), DatabaseError> { let (key, value) = - unsafe { &*self.table_codec() }.encode_tuple(table_name, &mut tuple, types)?; + unsafe { &*self.table_codec() }.encode_tuple(table_name, &mut tuple, serializers)?; if !is_overwrite && self.get(&key)?.is_some() { return Err(DatabaseError::DuplicatePrimaryKey); @@ -727,14 +750,14 @@ trait IndexImpl<'bytes, T: Transaction + 'bytes> { fn index_lookup( &self, bytes: &Bytes, - pk_indices: &[usize], + pk_indices: Option<&[usize]>, params: &IndexImplParams, ) -> Result; fn eq_to_res<'a>( &self, value: &DataValue, - pk_indices: &[usize], + pk_indices: Option<&[usize]>, params: &IndexImplParams<'a, T>, ) -> Result, DatabaseError>; @@ -770,13 +793,11 @@ struct NormalIndexImpl; struct CompositeIndexImpl; struct IndexImplParams<'a, T: Transaction> { - tuple_schema_ref: Arc>, - projections: Vec, - index_meta: IndexMetaRef, table_name: &'a str, - table_types: Vec, - with_pk: bool, + deserializers: Vec, + values_len: usize, + total_len: usize, tx: &'a T, } @@ -802,7 +823,7 @@ impl IndexImplParams<'_, T> { fn get_tuple_by_id( &self, - pk_indices: &[usize], + pk_indices: Option<&[usize]>, tuple_id: &TupleId, ) -> Result, DatabaseError> { let key = unsafe { &*self.table_codec() }.encode_tuple_key(self.table_name, tuple_id)?; @@ -811,12 +832,11 @@ impl IndexImplParams<'_, T> { .get(&key)? .map(|bytes| { TableCodec::decode_tuple( - &self.table_types, + &self.deserializers, pk_indices, - &self.projections, - &self.tuple_schema_ref, &bytes, - self.with_pk, + self.values_len, + self.total_len, ) }) .transpose() @@ -832,7 +852,7 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for IndexImplEnum { fn index_lookup( &self, bytes: &Bytes, - pk_indices: &[usize], + pk_indices: Option<&[usize]>, params: &IndexImplParams, ) -> Result { match self { @@ -846,7 +866,7 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for IndexImplEnum { fn eq_to_res<'a>( &self, value: &DataValue, - pk_indices: &[usize], + pk_indices: Option<&[usize]>, params: &IndexImplParams<'a, T>, ) -> Result, DatabaseError> { match self { @@ -876,23 +896,22 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for PrimaryKeyIndexIm fn index_lookup( &self, bytes: &Bytes, - pk_indices: &[usize], + pk_indices: Option<&[usize]>, params: &IndexImplParams, ) -> Result { TableCodec::decode_tuple( - ¶ms.table_types, + ¶ms.deserializers, pk_indices, - ¶ms.projections, - ¶ms.tuple_schema_ref, bytes, - params.with_pk, + params.values_len, + params.total_len, ) } fn eq_to_res<'a>( &self, value: &DataValue, - pk_indices: &[usize], + pk_indices: Option<&[usize]>, params: &IndexImplParams<'a, T>, ) -> Result, DatabaseError> { let tuple = params @@ -900,12 +919,11 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for PrimaryKeyIndexIm .get(&unsafe { &*params.table_codec() }.encode_tuple_key(params.table_name, value)?)? 
.map(|bytes| { TableCodec::decode_tuple( - ¶ms.table_types, + ¶ms.deserializers, pk_indices, - ¶ms.projections, - ¶ms.tuple_schema_ref, &bytes, - params.with_pk, + params.values_len, + params.total_len, ) }) .transpose()?; @@ -924,7 +942,7 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for PrimaryKeyIndexIm fn secondary_index_lookup( bytes: &Bytes, - pk_indices: &[usize], + pk_indices: Option<&[usize]>, params: &IndexImplParams, ) -> Result { let tuple_id = TableCodec::decode_index(bytes)?; @@ -937,7 +955,7 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for UniqueIndexImpl { fn index_lookup( &self, bytes: &Bytes, - pk_indices: &[usize], + pk_indices: Option<&[usize]>, params: &IndexImplParams, ) -> Result { secondary_index_lookup(bytes, pk_indices, params) @@ -946,7 +964,7 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for UniqueIndexImpl { fn eq_to_res<'a>( &self, value: &DataValue, - pk_indices: &[usize], + pk_indices: Option<&[usize]>, params: &IndexImplParams<'a, T>, ) -> Result, DatabaseError> { let Some(bytes) = params.tx.get(&self.bound_key(params, value, false)?)? else { @@ -975,7 +993,7 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for NormalIndexImpl { fn index_lookup( &self, bytes: &Bytes, - pk_indices: &[usize], + pk_indices: Option<&[usize]>, params: &IndexImplParams, ) -> Result { secondary_index_lookup(bytes, pk_indices, params) @@ -984,7 +1002,7 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for NormalIndexImpl { fn eq_to_res<'a>( &self, value: &DataValue, - _: &[usize], + _: Option<&[usize]>, params: &IndexImplParams<'a, T>, ) -> Result, DatabaseError> { let min = self.bound_key(params, value, false)?; @@ -1016,7 +1034,7 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for CompositeIndexImp fn index_lookup( &self, bytes: &Bytes, - pk_indices: &[usize], + pk_indices: Option<&[usize]>, params: &IndexImplParams, ) -> Result { secondary_index_lookup(bytes, pk_indices, params) @@ -1025,7 +1043,7 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for CompositeIndexImp fn eq_to_res<'a>( &self, value: &DataValue, - _: &[usize], + _: Option<&[usize]>, params: &IndexImplParams<'a, T>, ) -> Result, DatabaseError> { let min = self.bound_key(params, value, false)?; @@ -1056,11 +1074,10 @@ impl<'bytes, T: Transaction + 'bytes> IndexImpl<'bytes, T> for CompositeIndexImp pub struct TupleIter<'a, T: Transaction + 'a> { offset: usize, limit: Option, - table_types: Vec, - tuple_columns: Arc>, - remap_pk_indices: Vec, - projections: Vec, - with_pk: bool, + remap_pk_indices: Option>, + deserializers: Vec, + values_len: usize, + total_len: usize, iter: T::IterType<'a>, } @@ -1082,12 +1099,11 @@ impl<'a, T: Transaction + 'a> Iter for TupleIter<'a, T> { *limit -= 1; } let tuple = TableCodec::decode_tuple( - &self.table_types, - &self.remap_pk_indices, - &self.projections, - &self.tuple_columns, + &self.deserializers, + self.remap_pk_indices.as_deref(), &value, - self.with_pk, + self.values_len, + self.total_len, )?; return Ok(Some(tuple)); @@ -1101,7 +1117,7 @@ pub struct IndexIter<'a, T: Transaction> { offset: usize, limit: Option, - remap_pk_indices: Vec, + remap_pk_indices: Option>, params: IndexImplParams<'a, T>, inner: IndexImplEnum, // for buffering data @@ -1205,7 +1221,7 @@ impl Iter for IndexIter<'_, T> { match self.inner.eq_to_res( &val, - &self.remap_pk_indices, + self.remap_pk_indices.as_deref(), &self.params, )? 
{ IndexResult::Tuple(tuple) => { @@ -1231,7 +1247,7 @@ impl Iter for IndexIter<'_, T> { Self::limit_sub(&mut self.limit); let tuple = self.inner.index_lookup( &bytes, - &self.remap_pk_indices, + self.remap_pk_indices.as_deref(), &self.params, )?; @@ -1253,13 +1269,6 @@ pub trait Iter { fn next_tuple(&mut self) -> Result, DatabaseError>; } -fn remap_pk_indices(projection: &[usize], pk_indices: &[usize]) -> Vec { - pk_indices - .iter() - .filter_map(|pk| projection.binary_search(pk).ok()) - .collect() -} - #[cfg(test)] mod test { use crate::binder::test::build_t1_table; @@ -1459,9 +1468,9 @@ mod test { "t1", tuple, &[ - LogicalType::Integer, - LogicalType::Boolean, - LogicalType::Integer, + LogicalType::Integer.serializable(), + LogicalType::Boolean.serializable(), + LogicalType::Integer.serializable(), ], false, )?; @@ -1696,9 +1705,9 @@ mod test { "t1", tuple, &[ - LogicalType::Integer, - LogicalType::Boolean, - LogicalType::Integer, + LogicalType::Integer.serializable(), + LogicalType::Boolean.serializable(), + LogicalType::Integer.serializable(), ], false, )?; diff --git a/src/storage/rocksdb.rs b/src/storage/rocksdb.rs index 770e85f3..16c9b496 100644 --- a/src/storage/rocksdb.rs +++ b/src/storage/rocksdb.rs @@ -286,7 +286,10 @@ mod test { Some(DataValue::Int32(1)), vec![DataValue::Int32(1), DataValue::Boolean(true)], ), - &[LogicalType::Integer, LogicalType::Boolean], + &[ + LogicalType::Integer.serializable(), + LogicalType::Boolean.serializable(), + ], false, )?; transaction.append_tuple( @@ -295,7 +298,10 @@ mod test { Some(DataValue::Int32(2)), vec![DataValue::Int32(2), DataValue::Boolean(true)], ), - &[LogicalType::Integer, LogicalType::Boolean], + &[ + LogicalType::Integer.serializable(), + LogicalType::Boolean.serializable(), + ], false, )?; @@ -344,13 +350,17 @@ mod test { DataValue::Int32(3), DataValue::Int32(4), ]; + let deserializers = table + .columns() + .map(|column| column.datatype().serializable()) + .collect_vec(); + let values_len = deserializers.len(); let mut iter = IndexIter { offset: 0, limit: None, - remap_pk_indices: vec![0], + remap_pk_indices: Some(vec![0]), params: IndexImplParams { - tuple_schema_ref: table.schema_ref().clone(), - projections: vec![0], + deserializers, index_meta: Arc::new(IndexMeta { id: 0, column_ids: vec![*a_column_id], @@ -361,9 +371,9 @@ mod test { ty: IndexType::PrimaryKey { is_multiple: false }, }), table_name: &table.name, - table_types: table.types(), - with_pk: true, tx: &transaction, + values_len, + total_len: 1, }, ranges: vec![ Range::Eq(DataValue::Int32(0)), @@ -395,7 +405,7 @@ mod test { .run("create table t1 (a int primary key, b int unique)")? .done()?; kite_sql - .run("insert into t1 (a, b) values (0, 0), (1, 1), (2, 2)")? + .run("insert into t1 (a, b) values (0, 0), (1, 1), (2, 2), (3, 4)")? .done()?; let transaction = kite_sql.storage.transaction().unwrap(); @@ -403,25 +413,51 @@ mod test { .table(kite_sql.state.table_cache(), Arc::new("t1".to_string()))? .unwrap() .clone(); - let columns = table.columns().cloned().enumerate().collect(); - let mut iter = transaction - .read_by_index( - kite_sql.state.table_cache(), - Arc::new("t1".to_string()), - (Some(0), Some(1)), - columns, - table.indexes[0].clone(), - vec![Range::Scope { - min: Bound::Excluded(DataValue::Int32(0)), - max: Bound::Unbounded, - }], - true, - ) - .unwrap(); - - while let Some(tuple) = iter.next_tuple()? 
{ - assert_eq!(tuple.pk, Some(DataValue::Int32(1))); - assert_eq!(tuple.values, vec![DataValue::Int32(1), DataValue::Int32(1)]) + { + let mut iter = transaction + .read_by_index( + kite_sql.state.table_cache(), + Arc::new("t1".to_string()), + (Some(0), Some(1)), + table.columns().cloned().enumerate().collect(), + table.indexes[0].clone(), + vec![Range::Scope { + min: Bound::Excluded(DataValue::Int32(0)), + max: Bound::Unbounded, + }], + true, + ) + .unwrap(); + + while let Some(tuple) = iter.next_tuple()? { + assert_eq!(tuple.pk, Some(DataValue::Int32(1))); + assert_eq!(tuple.values, vec![DataValue::Int32(1), DataValue::Int32(1)]) + } + } + // projection + { + let mut columns: BTreeMap<_, _> = table.columns().cloned().enumerate().collect(); + let _ = columns.pop_last(); + + let mut iter = transaction + .read_by_index( + kite_sql.state.table_cache(), + Arc::new("t1".to_string()), + (Some(0), Some(1)), + columns, + table.indexes[0].clone(), + vec![Range::Scope { + min: Bound::Excluded(DataValue::Int32(3)), + max: Bound::Unbounded, + }], + true, + ) + .unwrap(); + + while let Some(tuple) = iter.next_tuple()? { + assert_eq!(tuple.pk, Some(DataValue::Int32(3))); + assert_eq!(tuple.values, vec![DataValue::Int32(3)]) + } } Ok(()) diff --git a/src/storage/table_codec.rs b/src/storage/table_codec.rs index 3d99c0f0..91fca64e 100644 --- a/src/storage/table_codec.rs +++ b/src/storage/table_codec.rs @@ -4,7 +4,8 @@ use crate::errors::DatabaseError; use crate::serdes::{ReferenceSerialization, ReferenceTables}; use crate::storage::{TableCache, Transaction}; use crate::types::index::{Index, IndexId, IndexMeta, IndexType}; -use crate::types::tuple::{Schema, Tuple, TupleId}; +use crate::types::serialize::TupleValueSerializableImpl; +use crate::types::tuple::{Tuple, TupleId}; use crate::types::value::DataValue; use crate::types::LogicalType; use bumpalo::Bump; @@ -258,7 +259,7 @@ impl TableCodec { &self, table_name: &str, tuple: &mut Tuple, - types: &[LogicalType], + types: &[TupleValueSerializableImpl], ) -> Result<(BumpBytes, BumpBytes), DatabaseError> { let tuple_id = tuple.pk.as_ref().ok_or(DatabaseError::PrimaryKeyNotFound)?; let key = self.encode_tuple_key(table_name, tuple_id)?; @@ -283,14 +284,13 @@ impl TableCodec { #[inline] pub fn decode_tuple( - table_types: &[LogicalType], - pk_indices: &[usize], - projections: &[usize], - schema: &Schema, + deserializers: &[TupleValueSerializableImpl], + pk_indices: Option<&[usize]>, bytes: &[u8], - with_pk: bool, + values_len: usize, + total_len: usize, ) -> Result { - Tuple::deserialize_from(table_types, pk_indices, projections, schema, bytes, with_pk) + Tuple::deserialize_from(deserializers, pk_indices, bytes, values_len, total_len) } pub fn encode_index_meta_key( @@ -584,21 +584,19 @@ mod tests { let (_, bytes) = table_codec.encode_tuple( &table_catalog.name, &mut tuple, - &[LogicalType::Integer, LogicalType::Decimal(None, None)], + &[ + LogicalType::Integer.serializable(), + LogicalType::Decimal(None, None).serializable(), + ], )?; - let schema = table_catalog.schema_ref(); - let pk_indices = table_catalog.primary_keys_indices(); + let deserializers = table_catalog + .columns() + .map(|column| column.datatype().serializable()) + .collect_vec(); tuple.pk = None; assert_eq!( - TableCodec::decode_tuple( - &table_catalog.types(), - pk_indices, - &[0, 1], - schema, - &bytes, - false - )?, + TableCodec::decode_tuple(&deserializers, None, &bytes, deserializers.len(), 2,)?, tuple ); diff --git a/src/types/mod.rs b/src/types/mod.rs index 7e92ca37..fedb69f9 100644 
--- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,5 +1,6 @@ pub mod evaluator; pub mod index; +pub mod serialize; pub mod tuple; pub mod tuple_builder; pub mod value; diff --git a/src/types/serialize.rs b/src/types/serialize.rs new file mode 100644 index 00000000..ec207b92 --- /dev/null +++ b/src/types/serialize.rs @@ -0,0 +1,580 @@ +use crate::errors::DatabaseError; +use crate::types::value::{DataValue, Utf8Type}; +use crate::types::LogicalType; +use bumpalo::collections::Vec; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use ordered_float::OrderedFloat; +use rust_decimal::Decimal; +use sqlparser::ast::CharLengthUnits; +use std::fmt::Debug; +use std::io::{Cursor, Read, Seek, SeekFrom, Write}; + +macro_rules! impl_tuple_value_serializable { + ($name:ident, $variant:ident, $write_fn:expr, $read_fn:expr) => { + impl TupleValueSerializable for $name { + fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError> { + let DataValue::$variant(v) = value else { + unsafe { std::hint::unreachable_unchecked() } + }; + ($write_fn)(writer, v)?; + Ok(()) + } + + fn from_raw(&self, reader: &mut Cursor<&[u8]>) -> Result { + Ok(DataValue::$variant(($read_fn)(reader)?)) + } + } + }; +} + +pub trait TupleValueSerializable: Debug { + fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError>; + #[allow(clippy::wrong_self_convention)] + fn from_raw(&self, reader: &mut Cursor<&[u8]>) -> Result; + fn filling_value( + &self, + reader: &mut Cursor<&[u8]>, + values: &mut std::vec::Vec, + ) -> Result<(), DatabaseError> { + values.push(self.from_raw(reader)?); + Ok(()) + } +} + +#[derive(Debug)] +pub enum TupleValueSerializableImpl { + Boolean, + Int8, + Int16, + Int32, + Int64, + UInt8, + UInt16, + UInt32, + UInt64, + Float32, + Float64, + Char { + len: u32, + unit: CharLengthUnits, + }, + Varchar { + len: Option, + unit: CharLengthUnits, + }, + Date, + DateTime, + Time { + precision: Option, + }, + Timestamp { + precision: Option, + zone: bool, + }, + Decimal, + SkipFixed(usize), + SkipVariable, +} + +impl TupleValueSerializable for TupleValueSerializableImpl { + fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError> { + match self { + TupleValueSerializableImpl::Boolean => BooleanSerializable.to_raw(value, writer), + TupleValueSerializableImpl::Int8 => Int8Serializable.to_raw(value, writer), + TupleValueSerializableImpl::Int16 => Int16Serializable.to_raw(value, writer), + TupleValueSerializableImpl::Int32 => Int32Serializable.to_raw(value, writer), + TupleValueSerializableImpl::Int64 => Int64Serializable.to_raw(value, writer), + TupleValueSerializableImpl::UInt8 => UInt8Serializable.to_raw(value, writer), + TupleValueSerializableImpl::UInt16 => UInt16Serializable.to_raw(value, writer), + TupleValueSerializableImpl::UInt32 => UInt32Serializable.to_raw(value, writer), + TupleValueSerializableImpl::UInt64 => UInt64Serializable.to_raw(value, writer), + TupleValueSerializableImpl::Float32 => Float32Serializable.to_raw(value, writer), + TupleValueSerializableImpl::Float64 => Float64Serializable.to_raw(value, writer), + TupleValueSerializableImpl::Char { len, unit } => CharSerializable { + len: *len, + unit: *unit, + } + .to_raw(value, writer), + TupleValueSerializableImpl::Varchar { len, unit } => VarcharSerializable { + len: *len, + unit: *unit, + } + .to_raw(value, writer), + TupleValueSerializableImpl::Date => DateSerializable.to_raw(value, writer), + TupleValueSerializableImpl::DateTime => DateTimeSerializable.to_raw(value, 
writer), + TupleValueSerializableImpl::Time { precision } => TimeSerializable { + precision: *precision, + } + .to_raw(value, writer), + TupleValueSerializableImpl::Timestamp { precision, zone } => TimeStampSerializable { + precision: *precision, + zone: *zone, + } + .to_raw(value, writer), + TupleValueSerializableImpl::Decimal => DecimalSerializable.to_raw(value, writer), + TupleValueSerializableImpl::SkipFixed(len) => SkipFixed(*len).to_raw(value, writer), + TupleValueSerializableImpl::SkipVariable => SkipVariable.to_raw(value, writer), + } + } + + fn from_raw(&self, reader: &mut Cursor<&[u8]>) -> Result { + match self { + TupleValueSerializableImpl::Boolean => BooleanSerializable.from_raw(reader), + TupleValueSerializableImpl::Int8 => Int8Serializable.from_raw(reader), + TupleValueSerializableImpl::Int16 => Int16Serializable.from_raw(reader), + TupleValueSerializableImpl::Int32 => Int32Serializable.from_raw(reader), + TupleValueSerializableImpl::Int64 => Int64Serializable.from_raw(reader), + TupleValueSerializableImpl::UInt8 => UInt8Serializable.from_raw(reader), + TupleValueSerializableImpl::UInt16 => UInt16Serializable.from_raw(reader), + TupleValueSerializableImpl::UInt32 => UInt32Serializable.from_raw(reader), + TupleValueSerializableImpl::UInt64 => UInt64Serializable.from_raw(reader), + TupleValueSerializableImpl::Float32 => Float32Serializable.from_raw(reader), + TupleValueSerializableImpl::Float64 => Float64Serializable.from_raw(reader), + TupleValueSerializableImpl::Char { len, unit } => CharSerializable { + len: *len, + unit: *unit, + } + .from_raw(reader), + TupleValueSerializableImpl::Varchar { len, unit } => VarcharSerializable { + len: *len, + unit: *unit, + } + .from_raw(reader), + TupleValueSerializableImpl::Date => DateSerializable.from_raw(reader), + TupleValueSerializableImpl::DateTime => DateTimeSerializable.from_raw(reader), + TupleValueSerializableImpl::Time { precision } => TimeSerializable { + precision: *precision, + } + .from_raw(reader), + TupleValueSerializableImpl::Timestamp { precision, zone } => TimeStampSerializable { + precision: *precision, + zone: *zone, + } + .from_raw(reader), + TupleValueSerializableImpl::Decimal => DecimalSerializable.from_raw(reader), + TupleValueSerializableImpl::SkipFixed(len) => SkipFixed(*len).from_raw(reader), + TupleValueSerializableImpl::SkipVariable => SkipVariable.from_raw(reader), + } + } + + fn filling_value( + &self, + reader: &mut Cursor<&[u8]>, + values: &mut std::vec::Vec, + ) -> Result<(), DatabaseError> { + match self { + TupleValueSerializableImpl::Boolean => { + BooleanSerializable.filling_value(reader, values) + } + TupleValueSerializableImpl::Int8 => Int8Serializable.filling_value(reader, values), + TupleValueSerializableImpl::Int16 => Int16Serializable.filling_value(reader, values), + TupleValueSerializableImpl::Int32 => Int32Serializable.filling_value(reader, values), + TupleValueSerializableImpl::Int64 => Int64Serializable.filling_value(reader, values), + TupleValueSerializableImpl::UInt8 => UInt8Serializable.filling_value(reader, values), + TupleValueSerializableImpl::UInt16 => UInt16Serializable.filling_value(reader, values), + TupleValueSerializableImpl::UInt32 => UInt32Serializable.filling_value(reader, values), + TupleValueSerializableImpl::UInt64 => UInt64Serializable.filling_value(reader, values), + TupleValueSerializableImpl::Float32 => { + Float32Serializable.filling_value(reader, values) + } + TupleValueSerializableImpl::Float64 => { + Float64Serializable.filling_value(reader, values) + } + 
TupleValueSerializableImpl::Char { len, unit } => CharSerializable { + len: *len, + unit: *unit, + } + .filling_value(reader, values), + TupleValueSerializableImpl::Varchar { len, unit } => VarcharSerializable { + len: *len, + unit: *unit, + } + .filling_value(reader, values), + TupleValueSerializableImpl::Date => DateSerializable.filling_value(reader, values), + TupleValueSerializableImpl::DateTime => { + DateTimeSerializable.filling_value(reader, values) + } + TupleValueSerializableImpl::Time { precision } => TimeSerializable { + precision: *precision, + } + .filling_value(reader, values), + TupleValueSerializableImpl::Timestamp { precision, zone } => TimeStampSerializable { + precision: *precision, + zone: *zone, + } + .filling_value(reader, values), + TupleValueSerializableImpl::Decimal => { + DecimalSerializable.filling_value(reader, values) + } + TupleValueSerializableImpl::SkipFixed(len) => { + SkipFixed(*len).filling_value(reader, values) + } + TupleValueSerializableImpl::SkipVariable => SkipVariable.filling_value(reader, values), + } + } +} + +#[derive(Debug)] +struct BooleanSerializable; + +#[derive(Debug)] +struct Int8Serializable; +#[derive(Debug)] +struct Int16Serializable; +#[derive(Debug)] +struct Int32Serializable; +#[derive(Debug)] +struct Int64Serializable; + +#[derive(Debug)] +struct UInt8Serializable; +#[derive(Debug)] +struct UInt16Serializable; +#[derive(Debug)] +struct UInt32Serializable; +#[derive(Debug)] +struct UInt64Serializable; + +#[derive(Debug)] +struct Float32Serializable; +#[derive(Debug)] +struct Float64Serializable; + +#[derive(Debug)] +struct CharSerializable { + len: u32, + unit: CharLengthUnits, +} +#[derive(Debug)] +struct VarcharSerializable { + len: Option, + unit: CharLengthUnits, +} + +#[derive(Debug)] +struct DateSerializable; +#[derive(Debug)] +struct DateTimeSerializable; + +#[derive(Debug)] +struct TimeSerializable { + precision: Option, +} +#[derive(Debug)] +struct TimeStampSerializable { + precision: Option, + zone: bool, +} + +#[derive(Debug)] +struct DecimalSerializable; + +#[derive(Debug)] +struct SkipFixed(usize); +#[derive(Debug)] +struct SkipVariable; + +// Int +impl_tuple_value_serializable!( + Int8Serializable, + Int8, + |writer: &mut Vec, &value| writer.write_i8(value), + |reader: &mut Cursor<&[u8]>| reader.read_i8() +); +impl_tuple_value_serializable!( + Int16Serializable, + Int16, + |writer: &mut Vec, &value| writer.write_i16::(value), + |reader: &mut Cursor<&[u8]>| reader.read_i16::() +); +impl_tuple_value_serializable!( + Int32Serializable, + Int32, + |writer: &mut Vec, &value| writer.write_i32::(value), + |reader: &mut Cursor<&[u8]>| reader.read_i32::() +); +impl_tuple_value_serializable!( + Int64Serializable, + Int64, + |writer: &mut Vec, &value| writer.write_i64::(value), + |reader: &mut Cursor<&[u8]>| reader.read_i64::() +); + +// Uint +impl_tuple_value_serializable!( + UInt8Serializable, + UInt8, + |writer: &mut Vec, &value| writer.write_u8(value), + |reader: &mut Cursor<&[u8]>| reader.read_u8() +); +impl_tuple_value_serializable!( + UInt16Serializable, + UInt16, + |writer: &mut Vec, &value| writer.write_u16::(value), + |reader: &mut Cursor<&[u8]>| reader.read_u16::() +); +impl_tuple_value_serializable!( + UInt32Serializable, + UInt32, + |writer: &mut Vec, &value| writer.write_u32::(value), + |reader: &mut Cursor<&[u8]>| reader.read_u32::() +); +impl_tuple_value_serializable!( + UInt64Serializable, + UInt64, + |writer: &mut Vec, &value| writer.write_u64::(value), + |reader: &mut Cursor<&[u8]>| reader.read_u64::() +); + +// 
Float +impl_tuple_value_serializable!( + Float32Serializable, + Float32, + |writer: &mut Vec, value: &OrderedFloat::| writer + .write_f32::(value.into_inner()), + |reader: &mut Cursor<&[u8]>| reader.read_f32::().map(OrderedFloat::) +); +impl_tuple_value_serializable!( + Float64Serializable, + Float64, + |writer: &mut Vec, value: &OrderedFloat::| writer + .write_f64::(value.into_inner()), + |reader: &mut Cursor<&[u8]>| reader.read_f64::().map(OrderedFloat::) +); + +impl_tuple_value_serializable!( + BooleanSerializable, + Boolean, + |writer: &mut Vec, &value| writer.write_u8(value as u8), + |reader: &mut Cursor<&[u8]>| reader.read_u8().map(|v| v != 0) +); + +impl TupleValueSerializable for CharSerializable { + fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError> { + let DataValue::Utf8 { + value, + unit, + ty: Utf8Type::Fixed(len), + } = value + else { + unsafe { std::hint::unreachable_unchecked() } + }; + match unit { + CharLengthUnits::Characters => { + let chars_len = *len as usize; + let v = format!("{:len$}", value, len = chars_len); + let bytes = v.as_bytes(); + + writer.write_u32::(bytes.len() as u32)?; + writer.write_all(bytes)?; + } + CharLengthUnits::Octets => { + let octets_len = *len as usize; + let bytes = value.as_bytes(); + debug_assert!(octets_len >= bytes.len()); + + writer.write_all(bytes)?; + for _ in 0..octets_len - bytes.len() { + writer.write_u8(b' ')?; + } + } + } + Ok(()) + } + + fn from_raw(&self, reader: &mut Cursor<&[u8]>) -> Result { + // https://dev.mysql.com/doc/refman/8.0/en/char.html#:~:text=If%20a%20given%20value%20is%20stored%20into%20the%20CHAR(4)%20and%20VARCHAR(4)%20columns%2C%20the%20values%20retrieved%20from%20the%20columns%20are%20not%20always%20the%20same%20because%20trailing%20spaces%20are%20removed%20from%20CHAR%20columns%20upon%20retrieval.%20The%20following%20example%20illustrates%20this%20difference%3A + let len = match self.unit { + CharLengthUnits::Characters => reader.read_u32::()?, + CharLengthUnits::Octets => self.len, + } as usize; + let mut bytes = vec![0; len]; + reader.read_exact(&mut bytes)?; + let last_non_zero_index = match bytes.iter().rposition(|&x| x != b' ') { + Some(index) => index + 1, + None => 0, + }; + bytes.truncate(last_non_zero_index); + + Ok(DataValue::Utf8 { + value: String::from_utf8(bytes)?, + ty: Utf8Type::Fixed(self.len), + unit: self.unit, + }) + } +} + +impl TupleValueSerializable for VarcharSerializable { + fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError> { + let DataValue::Utf8 { + value, + ty: Utf8Type::Variable(_), + .. + } = value + else { + unsafe { std::hint::unreachable_unchecked() } + }; + let bytes = value.as_bytes(); + + writer.write_u32::(bytes.len() as u32)?; + writer.write_all(bytes)?; + Ok(()) + } + + fn from_raw(&self, reader: &mut Cursor<&[u8]>) -> Result { + let len = reader.read_u32::()? 
as usize; + let mut bytes = vec![0; len]; + reader.read_exact(&mut bytes)?; + + Ok(DataValue::Utf8 { + value: String::from_utf8(bytes)?, + ty: Utf8Type::Variable(self.len), + unit: self.unit, + }) + } +} + +impl_tuple_value_serializable!( + DateSerializable, + Date32, + |writer: &mut Vec, &value| writer.write_i32::(value), + |reader: &mut Cursor<&[u8]>| reader.read_i32::() +); +impl_tuple_value_serializable!( + DateTimeSerializable, + Date64, + |writer: &mut Vec, &value| writer.write_i64::(value), + |reader: &mut Cursor<&[u8]>| reader.read_i64::() +); + +impl TupleValueSerializable for TimeSerializable { + fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError> { + let DataValue::Time32(v, ..) = value else { + unsafe { std::hint::unreachable_unchecked() } + }; + writer.write_u32::(*v)?; + Ok(()) + } + + fn from_raw(&self, reader: &mut Cursor<&[u8]>) -> Result { + let precision = self.precision.unwrap_or_default(); + Ok(DataValue::Time32( + reader.read_u32::()?, + precision, + )) + } +} + +impl TupleValueSerializable for TimeStampSerializable { + fn to_raw(&self, value: &DataValue, writer: &mut Vec) -> Result<(), DatabaseError> { + let DataValue::Time64(v, ..) = value else { + unsafe { std::hint::unreachable_unchecked() } + }; + writer.write_i64::(*v)?; + Ok(()) + } + + fn from_raw(&self, reader: &mut Cursor<&[u8]>) -> Result { + let precision = self.precision.unwrap_or_default(); + Ok(DataValue::Time64( + reader.read_i64::()?, + precision, + self.zone, + )) + } +} + +impl_tuple_value_serializable!( + DecimalSerializable, + Decimal, + |writer: &mut Vec, &value: &Decimal| writer.write_all(&value.serialize()), + |reader: &mut Cursor<&[u8]>| { + let mut bytes = [0u8; 16]; + reader.read_exact(&mut bytes)?; + Result::<_, DatabaseError>::Ok(Decimal::deserialize(bytes)) + } +); + +impl TupleValueSerializable for SkipFixed { + fn to_raw(&self, _: &DataValue, _: &mut Vec) -> Result<(), DatabaseError> { + unreachable!(); + } + + fn from_raw(&self, reader: &mut Cursor<&[u8]>) -> Result { + reader.seek(SeekFrom::Current(self.0 as i64))?; + Ok(DataValue::Null) + } + + fn filling_value( + &self, + reader: &mut Cursor<&[u8]>, + _: &mut std::vec::Vec, + ) -> Result<(), DatabaseError> { + let _ = self.from_raw(reader)?; + Ok(()) + } +} + +impl TupleValueSerializable for SkipVariable { + fn to_raw(&self, _: &DataValue, _: &mut Vec) -> Result<(), DatabaseError> { + unreachable!(); + } + + fn from_raw(&self, reader: &mut Cursor<&[u8]>) -> Result { + let len = reader.read_u32::()? 
as usize; + reader.seek(SeekFrom::Current(len as i64))?; + Ok(DataValue::Null) + } + + fn filling_value( + &self, + reader: &mut Cursor<&[u8]>, + _: &mut std::vec::Vec, + ) -> Result<(), DatabaseError> { + let _ = self.from_raw(reader)?; + Ok(()) + } +} + +impl LogicalType { + pub fn skip_serializable(&self) -> TupleValueSerializableImpl { + self.raw_len() + .map(TupleValueSerializableImpl::SkipFixed) + .unwrap_or(TupleValueSerializableImpl::SkipVariable) + } + + pub fn serializable(&self) -> TupleValueSerializableImpl { + match self { + LogicalType::Boolean => TupleValueSerializableImpl::Boolean, + LogicalType::Tinyint => TupleValueSerializableImpl::Int8, + LogicalType::UTinyint => TupleValueSerializableImpl::UInt8, + LogicalType::Smallint => TupleValueSerializableImpl::Int16, + LogicalType::USmallint => TupleValueSerializableImpl::UInt16, + LogicalType::Integer => TupleValueSerializableImpl::Int32, + LogicalType::UInteger => TupleValueSerializableImpl::UInt32, + LogicalType::Bigint => TupleValueSerializableImpl::Int64, + LogicalType::UBigint => TupleValueSerializableImpl::UInt64, + LogicalType::Float => TupleValueSerializableImpl::Float32, + LogicalType::Double => TupleValueSerializableImpl::Float64, + LogicalType::Char(len, unit) => TupleValueSerializableImpl::Char { + len: *len, + unit: *unit, + }, + LogicalType::Varchar(len, unit) => TupleValueSerializableImpl::Varchar { + len: *len, + unit: *unit, + }, + LogicalType::Date => TupleValueSerializableImpl::Date, + LogicalType::DateTime => TupleValueSerializableImpl::DateTime, + LogicalType::Time(precision) => TupleValueSerializableImpl::Time { + precision: *precision, + }, + LogicalType::TimeStamp(precision, zone) => TupleValueSerializableImpl::Timestamp { + precision: *precision, + zone: *zone, + }, + LogicalType::Decimal(_, _) => TupleValueSerializableImpl::Decimal, + LogicalType::SqlNull | LogicalType::Tuple(_) => unreachable!(), + } + } +} diff --git a/src/types/tuple.rs b/src/types/tuple.rs index c1ba7bdc..af2850a3 100644 --- a/src/types/tuple.rs +++ b/src/types/tuple.rs @@ -2,8 +2,8 @@ use crate::catalog::ColumnRef; use crate::db::ResultIter; use crate::errors::DatabaseError; use crate::storage::table_codec::BumpBytes; +use crate::types::serialize::{TupleValueSerializable, TupleValueSerializableImpl}; use crate::types::value::DataValue; -use crate::types::LogicalType; use bumpalo::Bump; use comfy_table::{Cell, Table}; use itertools::Itertools; @@ -16,13 +16,6 @@ pub type TupleId = DataValue; pub type Schema = Vec; pub type SchemaRef = Arc; -pub fn types(schema: &Schema) -> Vec { - schema - .iter() - .map(|column| column.datatype().clone()) - .collect_vec() -} - #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Tuple { pub pk: Option, @@ -36,47 +29,31 @@ impl Tuple { #[inline] pub fn deserialize_from( - table_types: &[LogicalType], - pk_indices: &[usize], - projections: &[usize], - schema: &Schema, + deserializers: &[TupleValueSerializableImpl], + pk_indices: Option<&[usize]>, bytes: &[u8], - with_pk: bool, + values_len: usize, + total_len: usize, ) -> Result { - debug_assert!(!schema.is_empty()); - debug_assert!(projections.is_sorted()); - debug_assert_eq!(projections.len(), schema.len()); - - fn is_none(bits: u8, i: usize) -> bool { + fn is_null(bits: u8, i: usize) -> bool { bits & (1 << (7 - i)) > 0 } - let types_len = table_types.len(); - let bits_len = (types_len + BITS_MAX_INDEX) / BITS_MAX_INDEX; - let mut values = vec![DataValue::Null; projections.len()]; + let bits_len = (total_len + BITS_MAX_INDEX) / BITS_MAX_INDEX; 
+ let mut values = Vec::with_capacity(values_len); - let mut projection_i = 0; let mut cursor = Cursor::new(&bytes[bits_len..]); - for (i, logic_type) in table_types.iter().enumerate() { - if projections.len() <= projection_i { - break; - } - debug_assert!(projection_i < types_len); - if is_none(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX) { - projection_i += 1; + for (i, deserializer) in deserializers.iter().enumerate() { + if is_null(bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX) { + values.push(DataValue::Null); continue; } - if let Some(value) = - DataValue::from_raw(&mut cursor, logic_type, projections[projection_i] == i)? - { - values[projection_i] = value; - projection_i += 1; - } + deserializer.filling_value(&mut cursor, &mut values)?; } Ok(Tuple { - pk: with_pk.then(|| Tuple::primary_projection(pk_indices, &values)), + pk: pk_indices.map(|pk_indices| Tuple::primary_projection(pk_indices, &values)), values, }) } @@ -85,10 +62,10 @@ impl Tuple { /// Tips: all len is u32 pub fn serialize_to<'a>( &self, - types: &[LogicalType], + serializers: &[TupleValueSerializableImpl], arena: &'a Bump, ) -> Result, DatabaseError> { - debug_assert_eq!(self.values.len(), types.len()); + debug_assert_eq!(self.values.len(), serializers.len()); fn flip_bit(bits: u8, i: usize) -> u8 { bits | (1 << (7 - i)) @@ -99,15 +76,15 @@ impl Tuple { let mut bytes = BumpBytes::new_in(arena); bytes.resize(bits_len, 0u8); let null_bytes: *mut BumpBytes = &mut bytes; - let mut value_bytes = &mut bytes; - for (i, value) in self.values.iter().enumerate() { + debug_assert_eq!(self.values.len(), serializers.len()); + for (i, (value, serializer)) in self.values.iter().zip(serializers.iter()).enumerate() { if value.is_null() { let null_bytes = unsafe { &mut *null_bytes }; null_bytes[i / BITS_MAX_INDEX] = flip_bit(null_bytes[i / BITS_MAX_INDEX], i % BITS_MAX_INDEX); } else { - value.to_raw(&mut value_bytes)?; + serializer.to_raw(value, &mut bytes)?; } } Ok(bytes) @@ -333,20 +310,19 @@ mod tests { ], ), ]; - let types = columns + let serializers = columns .iter() - .map(|column| column.datatype().clone()) + .map(|column| column.datatype().serializable()) .collect_vec(); let columns = Arc::new(columns); let arena = Bump::new(); { let tuple_0 = Tuple::deserialize_from( - &types, - &Arc::new(vec![0]), - &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - &columns, - &tuples[0].serialize_to(&types, &arena).unwrap(), - true, + &serializers, + Some(vec![0]).as_deref(), + &tuples[0].serialize_to(&serializers, &arena).unwrap(), + serializers.len(), + columns.len(), ) .unwrap(); @@ -354,16 +330,85 @@ mod tests { } { let tuple_1 = Tuple::deserialize_from( - &types, - &Arc::new(vec![0]), - &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - &columns, - &tuples[1].serialize_to(&types, &arena).unwrap(), - true, + &serializers, + Some(vec![0]).as_deref(), + &tuples[1].serialize_to(&serializers, &arena).unwrap(), + serializers.len(), + columns.len(), ) .unwrap(); assert_eq!(tuples[1], tuple_1); } + // projection + { + let projection_serializers = vec![ + columns[0].datatype().serializable(), + columns[1].datatype().skip_serializable(), + columns[2].datatype().skip_serializable(), + columns[3].datatype().serializable(), + ]; + let tuple_2 = Tuple::deserialize_from( + &projection_serializers, + Some(vec![0]).as_deref(), + &tuples[0].serialize_to(&serializers, &arena).unwrap(), + 2, + columns.len(), + ) + .unwrap(); + + assert_eq!( + tuple_2, + Tuple { + pk: Some(DataValue::Int32(0)), + values: vec![DataValue::Int32(0), 
DataValue::Int16(1)], + } + ); + } + // multiple pk + { + let multiple_pk_serializers = columns + .iter() + .take(5) + .map(|column| column.datatype().serializable()) + .collect_vec(); + + let tuple_3 = Tuple::deserialize_from( + &multiple_pk_serializers, + Some(vec![4, 2]).as_deref(), + &tuples[0].serialize_to(&serializers, &arena).unwrap(), + serializers.len(), + columns.len(), + ) + .unwrap(); + + assert_eq!( + tuple_3, + Tuple { + pk: Some(DataValue::Tuple( + vec![ + DataValue::UInt16(1), + DataValue::Utf8 { + value: "LOL".to_string(), + ty: Utf8Type::Variable(Some(2)), + unit: CharLengthUnits::Octets, + }, + ], + false + )), + values: vec![ + DataValue::Int32(0), + DataValue::UInt32(1), + DataValue::Utf8 { + value: "LOL".to_string(), + ty: Utf8Type::Variable(Some(2)), + unit: CharLengthUnits::Characters, + }, + DataValue::Int16(1), + DataValue::UInt16(1), + ], + } + ); + } } } diff --git a/src/types/value.rs b/src/types/value.rs index 1b83053d..bb19254d 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -1,7 +1,6 @@ use super::LogicalType; use crate::errors::DatabaseError; use crate::storage::table_codec::{BumpBytes, BOUND_MAX_TAG, BOUND_MIN_TAG}; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use chrono::format::{DelayedFormat, StrftimeItems}; use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; use itertools::Itertools; @@ -12,7 +11,7 @@ use sqlparser::ast::CharLengthUnits; use std::cmp::Ordering; use std::fmt::Formatter; use std::hash::Hash; -use std::io::{Read, Seek, SeekFrom, Write}; +use std::io::Write; use std::str::FromStr; use std::sync::LazyLock; use std::{cmp, fmt, mem}; @@ -526,290 +525,6 @@ impl DataValue { } } - #[inline] - pub fn to_raw(&self, writer: &mut W) -> Result<(), DatabaseError> { - match self { - DataValue::Null => (), - DataValue::Boolean(v) => { - writer.write_u8(*v as u8)?; - return Ok(()); - } - DataValue::Float32(v) => { - writer.write_f32::(v.0)?; - return Ok(()); - } - DataValue::Float64(v) => { - writer.write_f64::(v.0)?; - return Ok(()); - } - DataValue::Int8(v) => { - writer.write_i8(*v)?; - return Ok(()); - } - DataValue::Int16(v) => { - writer.write_i16::(*v)?; - return Ok(()); - } - DataValue::Int32(v) => { - writer.write_i32::(*v)?; - return Ok(()); - } - DataValue::Int64(v) => { - writer.write_i64::(*v)?; - return Ok(()); - } - DataValue::UInt8(v) => { - writer.write_u8(*v)?; - return Ok(()); - } - DataValue::UInt16(v) => { - writer.write_u16::(*v)?; - return Ok(()); - } - DataValue::UInt32(v) => { - writer.write_u32::(*v)?; - return Ok(()); - } - DataValue::UInt64(v) => { - writer.write_u64::(*v)?; - return Ok(()); - } - DataValue::Utf8 { value: v, ty, unit } => match ty { - Utf8Type::Variable(_) => { - let bytes = v.as_bytes(); - - writer.write_u32::(bytes.len() as u32)?; - writer.write_all(bytes)?; - return Ok(()); - } - Utf8Type::Fixed(len) => match unit { - CharLengthUnits::Characters => { - let chars_len = *len as usize; - let v = format!("{:len$}", v, len = chars_len); - let bytes = v.as_bytes(); - - writer.write_u32::(bytes.len() as u32)?; - writer.write_all(bytes)?; - return Ok(()); - } - CharLengthUnits::Octets => { - let octets_len = *len as usize; - let bytes = v.as_bytes(); - debug_assert!(octets_len >= bytes.len()); - - writer.write_all(bytes)?; - for _ in 0..octets_len - bytes.len() { - writer.write_u8(b' ')?; - } - return Ok(()); - } - }, - }, - DataValue::Date32(v) => { - writer.write_i32::(*v)?; - return Ok(()); - } - DataValue::Date64(v) => { - writer.write_i64::(*v)?; - 
return Ok(()); - } - DataValue::Time32(v, ..) => { - writer.write_u32::(*v)?; - return Ok(()); - } - DataValue::Time64(v, ..) => { - writer.write_i64::(*v)?; - return Ok(()); - } - DataValue::Decimal(v) => { - writer.write_all(&v.serialize())?; - return Ok(()); - } - DataValue::Tuple(..) => unreachable!(), - } - Ok(()) - } - - #[inline] - pub fn from_raw( - reader: &mut R, - ty: &LogicalType, - is_projection: bool, - ) -> Result, DatabaseError> { - let value = match ty { - LogicalType::SqlNull => { - if !is_projection { - return Ok(None); - } - DataValue::Null - } - LogicalType::Boolean => { - if !is_projection { - reader.seek(SeekFrom::Current(1))?; - return Ok(None); - } - DataValue::Boolean(reader.read_u8()? != 0) - } - LogicalType::Tinyint => { - if !is_projection { - reader.seek(SeekFrom::Current(1))?; - return Ok(None); - } - DataValue::Int8(reader.read_i8()?) - } - LogicalType::UTinyint => { - if !is_projection { - reader.seek(SeekFrom::Current(1))?; - return Ok(None); - } - DataValue::UInt8(reader.read_u8()?) - } - LogicalType::Smallint => { - if !is_projection { - reader.seek(SeekFrom::Current(2))?; - return Ok(None); - } - DataValue::Int16(reader.read_i16::()?) - } - LogicalType::USmallint => { - if !is_projection { - reader.seek(SeekFrom::Current(2))?; - return Ok(None); - } - DataValue::UInt16(reader.read_u16::()?) - } - LogicalType::Integer => { - if !is_projection { - reader.seek(SeekFrom::Current(4))?; - return Ok(None); - } - DataValue::Int32(reader.read_i32::()?) - } - LogicalType::UInteger => { - if !is_projection { - reader.seek(SeekFrom::Current(4))?; - return Ok(None); - } - DataValue::UInt32(reader.read_u32::()?) - } - LogicalType::Bigint => { - if !is_projection { - reader.seek(SeekFrom::Current(8))?; - return Ok(None); - } - DataValue::Int64(reader.read_i64::()?) - } - LogicalType::UBigint => { - if !is_projection { - reader.seek(SeekFrom::Current(8))?; - return Ok(None); - } - DataValue::UInt64(reader.read_u64::()?) - } - LogicalType::Float => { - if !is_projection { - reader.seek(SeekFrom::Current(4))?; - return Ok(None); - } - DataValue::Float32(OrderedFloat(reader.read_f32::()?)) - } - LogicalType::Double => { - if !is_projection { - reader.seek(SeekFrom::Current(8))?; - return Ok(None); - } - DataValue::Float64(OrderedFloat(reader.read_f64::()?)) - } - LogicalType::Char(ty_len, unit) => { - // https://dev.mysql.com/doc/refman/8.0/en/char.html#:~:text=If%20a%20given%20value%20is%20stored%20into%20the%20CHAR(4)%20and%20VARCHAR(4)%20columns%2C%20the%20values%20retrieved%20from%20the%20columns%20are%20not%20always%20the%20same%20because%20trailing%20spaces%20are%20removed%20from%20CHAR%20columns%20upon%20retrieval.%20The%20following%20example%20illustrates%20this%20difference%3A - let len = match unit { - CharLengthUnits::Characters => reader.read_u32::()?, - CharLengthUnits::Octets => *ty_len, - } as usize; - if !is_projection { - reader.seek(SeekFrom::Current(len as i64))?; - return Ok(None); - } - let mut bytes = vec![0; len]; - reader.read_exact(&mut bytes)?; - let last_non_zero_index = match bytes.iter().rposition(|&x| x != b' ') { - Some(index) => index + 1, - None => 0, - }; - bytes.truncate(last_non_zero_index); - - DataValue::Utf8 { - value: String::from_utf8(bytes)?, - ty: Utf8Type::Fixed(*ty_len), - unit: *unit, - } - } - LogicalType::Varchar(ty_len, unit) => { - let len = reader.read_u32::()? 
as usize; - if !is_projection { - reader.seek(SeekFrom::Current(len as i64))?; - return Ok(None); - } - let mut bytes = vec![0; len]; - reader.read_exact(&mut bytes)?; - - DataValue::Utf8 { - value: String::from_utf8(bytes)?, - ty: Utf8Type::Variable(*ty_len), - unit: *unit, - } - } - LogicalType::Date => { - if !is_projection { - reader.seek(SeekFrom::Current(4))?; - return Ok(None); - } - DataValue::Date32(reader.read_i32::()?) - } - LogicalType::DateTime => { - if !is_projection { - reader.seek(SeekFrom::Current(8))?; - return Ok(None); - } - DataValue::Date64(reader.read_i64::()?) - } - LogicalType::Time(precision) => { - let precision = match precision { - Some(precision) => *precision, - None => 0, - }; - if !is_projection { - reader.seek(SeekFrom::Current(4))?; - return Ok(None); - } - DataValue::Time32(reader.read_u32::()?, precision) - } - LogicalType::TimeStamp(precision, zone) => { - let precision = match precision { - Some(precision) => *precision, - None => 0, - }; - if !is_projection { - reader.seek(SeekFrom::Current(8))?; - return Ok(None); - } - DataValue::Time64(reader.read_i64::()?, precision, *zone) - } - LogicalType::Decimal(_, _) => { - if !is_projection { - reader.seek(SeekFrom::Current(16))?; - return Ok(None); - } - let mut bytes = [0u8; 16]; - reader.read_exact(&mut bytes)?; - - DataValue::Decimal(Decimal::deserialize(bytes)) - } - LogicalType::Tuple(_) => unreachable!(), - }; - Ok(Some(value)) - } - #[inline] pub fn logical_type(&self) -> LogicalType { match self {
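
A minimal usage sketch of the serializer API introduced by this patch, mirroring the round-trip exercised in the updated src/types/tuple.rs tests. It assumes crate-internal paths and the `Tuple::new`, `serialize_to`, and `deserialize_from` signatures shown above; `roundtrip_example` and `schema_types` are illustrative names only and are not part of the diff.

    use crate::errors::DatabaseError;
    use crate::types::tuple::Tuple;
    use crate::types::value::DataValue;
    use crate::types::LogicalType;
    use bumpalo::Bump;

    fn roundtrip_example() -> Result<(), DatabaseError> {
        // One serializer per column replaces the old `&[LogicalType]` argument.
        let schema_types = [LogicalType::Integer, LogicalType::Boolean];
        let serializers: Vec<_> = schema_types
            .iter()
            .map(|ty| ty.serializable())
            .collect();

        let tuple = Tuple::new(
            Some(DataValue::Int32(1)),
            vec![DataValue::Int32(1), DataValue::Boolean(true)],
        );

        let arena = Bump::new();
        let bytes = tuple.serialize_to(&serializers, &arena)?;

        // `Some(pk_indices)` rebuilds the primary key from the decoded values;
        // pass `None` to skip it. Columns pruned by a projection would use
        // `ty.skip_serializable()` so their bytes are skipped, not materialized.
        let decoded = Tuple::deserialize_from(
            &serializers,
            Some(vec![0]).as_deref(),
            &bytes,
            serializers.len(),  // number of values to materialize
            schema_types.len(), // total columns in the stored row (null-bitmap width)
        )?;
        assert_eq!(decoded, tuple);
        Ok(())
    }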