From 5c10f8c41b45e93ba86e46bc64552dd1d149cf4b Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Fri, 4 Aug 2023 21:48:20 +0800 Subject: [PATCH 1/7] feat(hash_join): implement basic Inner/Left/Right/Full Join operations found a problem that needs to be fixed: - when the `Projection` operator executes eval_column, the reading data is abnormal due to the duplication of the ColumnIdx subscript of the Catalog So `Binder::bind_project` need to support join --- Cargo.toml | 1 + src/binder/mod.rs | 6 +- src/binder/select.rs | 182 +++++++++++-- src/db.rs | 23 +- src/execution_v1/physical_plan/mod.rs | 3 + .../physical_plan/physical_hash_join.rs | 9 + .../physical_plan/physical_plan_builder.rs | 22 +- .../volcano_executor/hash_join.rs | 238 +++++++++++++++++ src/execution_v1/volcano_executor/insert.rs | 2 + src/execution_v1/volcano_executor/mod.rs | 14 +- .../volcano_executor/projection.rs | 27 +- src/planner/operator/join.rs | 24 +- src/util/hash_utils.rs | 250 ++++++++++++++++++ src/util/mod.rs | 3 +- 14 files changed, 751 insertions(+), 53 deletions(-) create mode 100644 src/execution_v1/physical_plan/physical_hash_join.rs create mode 100644 src/execution_v1/volcano_executor/hash_join.rs create mode 100644 src/util/hash_utils.rs diff --git a/Cargo.toml b/Cargo.toml index 861ab756..64bf918e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,7 @@ async-channel = "1.8.0" async-backtrace = "0.2.6" futures = "0.3.25" futures-lite = "1.12.0" +ahash = "0.8.3" [dev-dependencies] tokio-test = "0.4.2" diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 52037106..9b3f7ea4 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -4,7 +4,7 @@ pub mod expr; mod select; mod insert; -use std::collections::HashMap; +use std::collections::BTreeMap; use anyhow::Result; use sqlparser::ast::{Ident, ObjectName, SetExpr, Statement}; @@ -16,8 +16,8 @@ use crate::types::TableIdx; #[derive(Clone)] pub struct BinderContext { catalog: RootCatalog, - bind_table: HashMap, - aliases: HashMap, + bind_table: BTreeMap, + aliases: BTreeMap, group_by_exprs: Vec, agg_calls: Vec, index: u16, diff --git a/src/binder/select.rs b/src/binder/select.rs index d1e11f54..34677dd1 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -15,15 +15,19 @@ use crate::{ use super::Binder; -use crate::catalog::{DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME}; +use crate::catalog::{DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME, TableCatalog}; use anyhow::Result; use itertools::Itertools; +use sqlparser::ast; use sqlparser::ast::{ Expr, Ident, Join, JoinConstraint, JoinOperator, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins, }; +use crate::expression::BinaryOperator; use crate::planner::LogicalPlan; +use crate::planner::operator::join::JoinCondition; use crate::planner::operator::sort::{SortField, SortOperator}; +use crate::types::{LogicalType, TableIdx}; impl Binder { pub(crate) fn bind_query(&mut self, query: &Query) -> Result { @@ -110,18 +114,18 @@ impl Binder { let TableWithJoins { relation, joins } = &from[0]; - let mut plan = self.bind_single_table_ref(relation)?; + let (left_id, mut plan) = self.bind_single_table_ref(relation)?; if !joins.is_empty() { for join in joins { - plan = self.bind_join(plan, join)?; + plan = self.bind_join(left_id, plan, join)?; } } Ok(plan) } - fn bind_single_table_ref(&mut self, table: &TableFactor) -> Result { - let plan = match table { + fn bind_single_table_ref(&mut self, table: &TableFactor) -> Result<(TableIdx, LogicalPlan)> { + let plan_with_id = match table { TableFactor::Table { name, alias, .. } => { let obj_name = name .0 @@ -155,12 +159,12 @@ impl Binder { self.context.bind_table.insert(table.into(), table_ref_id); - ScanOperator::new(table_ref_id) + (table_ref_id, ScanOperator::new(table_ref_id)) } _ => unimplemented!(), }; - Ok(plan) + Ok(plan_with_id) } /// Normalize select item. @@ -216,34 +220,40 @@ impl Binder { Ok(exprs) } - fn bind_join(&mut self, left: LogicalPlan, join: &Join) -> Result { + fn bind_join(&mut self, left_id: TableIdx, left: LogicalPlan, join: &Join) -> Result { let Join { relation, join_operator, } = join; - let right = self.bind_single_table_ref(relation)?; + let (right_id, right) = self.bind_single_table_ref(relation)?; let join_type = match join_operator { JoinOperator::Inner(constraint) => (JoinType::Inner, Some(constraint)), - JoinOperator::LeftOuter(constraint) => (JoinType::LeftOuter, Some(constraint)), - JoinOperator::RightOuter(constraint) => (JoinType::RightOuter, Some(constraint)), - JoinOperator::FullOuter(constraint) => (JoinType::FullOuter, Some(constraint)), + JoinOperator::LeftOuter(constraint) => (JoinType::Left, Some(constraint)), + JoinOperator::RightOuter(constraint) => (JoinType::Right, Some(constraint)), + JoinOperator::FullOuter(constraint) => (JoinType::Full, Some(constraint)), JoinOperator::CrossJoin => (JoinType::Cross, None), - JoinOperator::LeftSemi(constraint) => (JoinType::LeftSemi, Some(constraint)), - JoinOperator::RightSemi(constraint) => (JoinType::RightSemi, Some(constraint)), - JoinOperator::LeftAnti(constraint) => (JoinType::LeftAnti, Some(constraint)), - JoinOperator::RightAnti(constraint) => (JoinType::RightAnti, Some(constraint)), _ => unimplemented!(), }; + let left_table = self.context.catalog + .get_table(left_id) + .cloned() + .expect("Left table not found"); + let right_table = self.context.catalog + .get_table(right_id) + .cloned() + .expect("Right table not found"); let on = match join_type.1 { - Some(constraint) => match constraint { - JoinConstraint::On(expr) => Some(self.bind_expr(expr)?), - _ => unimplemented!(), - }, - None => None, + Some(constraint) => self.bind_join_constraint( + &left_table, + &right_table, + constraint + )?, + None => JoinCondition::None, }; + Ok(LJoinOperator::new(left, right, on, join_type.0)) } @@ -331,6 +341,119 @@ impl Binder { Ok(LimitOperator::new(offset, limit, children)) } + + fn bind_join_constraint( + &mut self, + left_table: &TableCatalog, + right_table: &TableCatalog, + constraint: &JoinConstraint, + ) -> Result { + match constraint { + JoinConstraint::On(expr) => { + // left and right columns that match equi-join pattern + let mut on_keys: Vec<(ScalarExpression, ScalarExpression)> = vec![]; + // expression that didn't match equi-join pattern + let mut filter = vec![]; + + self.extract_join_keys(expr, &mut on_keys, &mut filter, left_table, right_table)?; + + // combine multiple filter exprs into one BinaryExpr + let join_filter = filter + .into_iter() + .reduce(|acc, expr| ScalarExpression::Binary { + op: BinaryOperator::And, + left_expr: Box::new(acc), + right_expr: Box::new(expr), + ty: LogicalType::Boolean, + }); + // TODO: handle cross join if on_keys is empty + Ok(JoinCondition::On { + on: on_keys, + filter: join_filter, + }) + } + _ => unimplemented!("not supported join constraint {:?}", constraint), + } + } + + /// for sqlrs + /// original idea from datafusion planner.rs + /// Extracts equijoin ON condition be a single Eq or multiple conjunctive Eqs + /// Filters matching this pattern are added to `accum` + /// Filters that don't match this pattern are added to `accum_filter` + /// Examples: + /// ```text + /// foo = bar => accum=[(foo, bar)] accum_filter=[] + /// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[] + /// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1] + /// ``` + fn extract_join_keys( + &mut self, + expr: &Expr, + accum: &mut Vec<(ScalarExpression, ScalarExpression)>, + accum_filter: &mut Vec, + left_schema: &TableCatalog, + right_schema: &TableCatalog, + ) -> Result<()> { + match expr { + Expr::BinaryOp { left, op, right } => match op { + ast::BinaryOperator::Eq => { + let left = self.bind_expr(left)?; + let right = self.bind_expr(right)?; + + match (&left, &right) { + // example: foo = bar + (ScalarExpression::ColumnRef(l), ScalarExpression::ColumnRef(r)) => { + // reorder left and right joins keys to pattern: (left, right) + if left_schema.contains_column(&l.name) + && right_schema.contains_column(&r.name) + { + accum.push((left, right)); + } else if left_schema.contains_column(&r.name) + && right_schema.contains_column(&l.name) + { + accum.push((right, left)); + } else { + accum_filter.push(self.bind_expr(expr)?); + } + } + // example: baz = 1 + _other => { + accum_filter.push(self.bind_expr(expr)?); + } + } + } + ast::BinaryOperator::And => { + // example: foo = bar AND baz > 1 + if let Expr::BinaryOp { left, op: _, right } = expr { + self.extract_join_keys( + left, + accum, + accum_filter, + left_schema, + right_schema, + )?; + self.extract_join_keys( + right, + accum, + accum_filter, + left_schema, + right_schema, + )?; + } + } + _other => { + // example: baz > 1 + accum_filter.push(self.bind_expr(expr)?); + } + }, + _other => { + // example: baz in (xxx), something else will convert to filter logic + accum_filter.push(self.bind_expr(expr)?); + } + } + Ok(()) + } } #[cfg(test)] @@ -343,11 +466,18 @@ mod tests { fn test_root_catalog() -> Result { let mut root = RootCatalog::new(); - let cols = vec![ + + let cols_t1 = vec![ ColumnCatalog::new("c1".to_string(), false, ColumnDesc::new(Integer, true)), ColumnCatalog::new("c2".to_string(), false, ColumnDesc::new(Integer, false)), ]; - let _ = root.add_table("t1".to_string(), cols)?; + let _ = root.add_table("t1".to_string(), cols_t1)?; + + let cols_t2 = vec![ + ColumnCatalog::new("c3".to_string(), false, ColumnDesc::new(Integer, true)), + ColumnCatalog::new("c4".to_string(), false, ColumnDesc::new(Integer, false)), + ]; + let _ = root.add_table("t2".to_string(), cols_t2)?; Ok(root) } @@ -405,6 +535,12 @@ mod tests { plan_8 ); + let plan_9 = select_sql_run("select * from t1 inner join t2 on c1 = c3 and c1 > 1")?; + println!( + "join:\n {:#?}", + plan_9 + ); + Ok(()) } } \ No newline at end of file diff --git a/src/db.rs b/src/db.rs index fd4ef8b5..34bd9641 100644 --- a/src/db.rs +++ b/src/db.rs @@ -150,10 +150,11 @@ mod test { tokio_test::block_on(async move { let _ = kipsql.run("create table t1 (a int, b int)").await?; - let _ = kipsql.run("insert into t1 values (1, 1), (2, 3), (5, 4)").await?; + let _ = kipsql.run("create table t2 (c int, d int)").await?; + let _ = kipsql.run("insert into t1 values (1, 1), (3, 3), (5, 4)").await?; + let _ = kipsql.run("insert into t2 values (1, 2), (2, 3), (5, 6)").await?; - println!("full:"); - let vec_batch_full_fields = kipsql.run("select * from t1").await?; + let vec_batch_full_fields = kipsql.run("select * from t1 right join t2 on a = c").await?; print_batches(&vec_batch_full_fields)?; println!("projection_and_filter:"); @@ -168,6 +169,22 @@ mod test { let vec_batch_limit=kipsql.run("select * from t1 limit 1 offset 1").await?; print_batches(&vec_batch_limit)?; + println!("inner join:"); + let vec_batch_inner_join = kipsql.run("select * from t1 inner join t2 on a = c").await?; + print_batches(&vec_batch_inner_join)?; + + println!("left join:"); + let vec_batch_left_join = kipsql.run("select * from t1 left join t2 on a = c").await?; + print_batches(&vec_batch_left_join)?; + + println!("right join:"); + let vec_batch_right_join = kipsql.run("select * from t1 right join t2 on a = c").await?; + print_batches(&vec_batch_right_join)?; + + println!("full join:"); + let vec_batch_full_join = kipsql.run("select * from t1 full join t2 on a = c").await?; + print_batches(&vec_batch_full_join)?; + Ok(()) }) } diff --git a/src/execution_v1/physical_plan/mod.rs b/src/execution_v1/physical_plan/mod.rs index 9b6dd24f..9bb2c2fb 100644 --- a/src/execution_v1/physical_plan/mod.rs +++ b/src/execution_v1/physical_plan/mod.rs @@ -1,5 +1,6 @@ use crate::execution_v1::physical_plan::physical_create_table::PhysicalCreateTable; use crate::execution_v1::physical_plan::physical_filter::PhysicalFilter; +use crate::execution_v1::physical_plan::physical_hash_join::PhysicalHashJoin; use crate::execution_v1::physical_plan::physical_insert::PhysicalInsert; use crate::execution_v1::physical_plan::physical_limit::PhysicalLimit; use crate::execution_v1::physical_plan::physical_projection::PhysicalProjection; @@ -16,6 +17,7 @@ pub(crate) mod physical_values; pub(crate) mod physical_filter; pub(crate) mod physical_sort; pub(crate) mod physical_limit; +pub(crate) mod physical_hash_join; #[derive(Debug)] pub enum PhysicalPlan { @@ -27,4 +29,5 @@ pub enum PhysicalPlan { Sort(PhysicalSort), Values(PhysicalValues), Limit(PhysicalLimit), + HashJoin(PhysicalHashJoin), } \ No newline at end of file diff --git a/src/execution_v1/physical_plan/physical_hash_join.rs b/src/execution_v1/physical_plan/physical_hash_join.rs new file mode 100644 index 00000000..c887cb79 --- /dev/null +++ b/src/execution_v1/physical_plan/physical_hash_join.rs @@ -0,0 +1,9 @@ +use crate::execution_v1::physical_plan::PhysicalPlan; +use crate::planner::operator::join::JoinOperator; + +#[derive(Debug)] +pub struct PhysicalHashJoin { + pub(crate) op: JoinOperator, + pub(crate) left_input: Box, + pub(crate) right_input: Box +} \ No newline at end of file diff --git a/src/execution_v1/physical_plan/physical_plan_builder.rs b/src/execution_v1/physical_plan/physical_plan_builder.rs index 371f5891..b780b201 100644 --- a/src/execution_v1/physical_plan/physical_plan_builder.rs +++ b/src/execution_v1/physical_plan/physical_plan_builder.rs @@ -8,6 +8,7 @@ use crate::planner::LogicalPlan; use anyhow::anyhow; use anyhow::Result; use crate::execution_v1::physical_plan::physical_filter::PhysicalFilter; +use crate::execution_v1::physical_plan::physical_hash_join::PhysicalHashJoin; use crate::execution_v1::physical_plan::physical_insert::PhysicalInsert; use crate::execution_v1::physical_plan::physical_limit::PhysicalLimit; use crate::execution_v1::physical_plan::physical_sort::PhysicalSort; @@ -15,6 +16,7 @@ use crate::execution_v1::physical_plan::physical_values::PhysicalValues; use crate::planner::operator::create_table::CreateTableOperator; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::insert::InsertOperator; +use crate::planner::operator::join::{JoinOperator, JoinType}; use crate::planner::operator::limit::LimitOperator; use crate::planner::operator::project::ProjectOperator; use crate::planner::operator::sort::SortOperator; @@ -36,7 +38,8 @@ impl PhysicalPlanBuilder { Operator::Insert(op) => self.build_insert(plan, op), Operator::Values(op) => Ok(Self::build_values(op)), Operator::Sort(op) => self.build_physical_sort(plan, op), - Operator::Limit(op)=>self.build_physical_limit(plan, op), + Operator::Limit(op) => self.build_physical_limit(plan, op), + Operator::Join(op) => self.build_physical_join(plan, op), _ => Err(anyhow!(format!( "Unsupported physical plan: {:?}", plan.operator @@ -99,7 +102,7 @@ impl PhysicalPlanBuilder { })) } - fn build_physical_limit(&mut self, plan: &LogicalPlan, base : &LimitOperator)->Result{ + fn build_physical_limit(&mut self, plan: &LogicalPlan, base: &LimitOperator) -> Result { let input =self.build_plan(plan.child(0)?)?; Ok(PhysicalPlan::Limit(PhysicalLimit{ @@ -107,4 +110,19 @@ impl PhysicalPlanBuilder { input: Box::new(input), })) } + + fn build_physical_join(&mut self, plan: &LogicalPlan, base: &JoinOperator) -> Result { + let left_input = Box::new(self.build_plan(plan.child(0)?)?); + let right_input = Box::new(self.build_plan(plan.child(1)?)?); + + if base.join_type == JoinType::Cross { + todo!() + } else { + Ok(PhysicalPlan::HashJoin(PhysicalHashJoin { + op: base.clone(), + left_input, + right_input, + })) + } + } } \ No newline at end of file diff --git a/src/execution_v1/volcano_executor/hash_join.rs b/src/execution_v1/volcano_executor/hash_join.rs new file mode 100644 index 00000000..2a839536 --- /dev/null +++ b/src/execution_v1/volcano_executor/hash_join.rs @@ -0,0 +1,238 @@ +use std::collections::BTreeSet; +use std::mem; +use std::sync::Arc; +use ahash::{HashMap, HashMapExt, RandomState}; +use arrow::array::{ArrayRef, BooleanArray, new_null_array, PrimitiveArray, UInt32Builder}; +use arrow::compute; +use arrow::datatypes::{Field, Schema, UInt32Type}; +use arrow::record_batch::RecordBatch; +use futures_async_stream::try_stream; +use itertools::Itertools; +use crate::execution_v1::ExecutorError; +use crate::execution_v1::volcano_executor::BoxedExecutor; +use crate::expression::ScalarExpression; +use crate::planner::operator::join::{JoinCondition, JoinType}; +use crate::util::hash_utils::create_hashes; + +pub struct HashJoin { } + +impl HashJoin { + #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] + pub async fn execute(on: JoinCondition, ty: JoinType, left_input: BoxedExecutor, right_input: BoxedExecutor) { + if ty == JoinType::Cross { + unreachable!("Cross join should not be in HashJoinExecutor"); + } + let ((on_left_keys, on_right_keys), filter) = match on { + JoinCondition::On { on, filter } => (on.into_iter().unzip(), filter), + JoinCondition::None => unreachable!("HashJoin must has on condition") + }; + + // build phase: + // 1.construct hashtable, one hash key may contains multiple rows indices. + // 2.merged all left batches into single batch. + let mut left_hashmap = HashMap::new(); + let mut left_row_offset = 0; + let mut left_batches = vec![]; + + let hash_random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut join_fields: Vec = Vec::new(); + // FIXME: 应该在Binder层处理,使Project能够同时获取此nullable信息(因为Join会改变主键的nullable判定) + let (left_force_nullable, right_force_nullable) = match ty { + JoinType::Inner => (false, false), + JoinType::Left => (false, true), + JoinType::Right => (true, false), + JoinType::Full => (true, true), + JoinType::Cross => (true, true), + }; + + #[for_await] + for batch in left_input { + let batch: RecordBatch = batch?; + let rows_hashes = Self::hash_columns(&on_left_keys, &hash_random_state, &batch)?; + + if join_fields.is_empty() { + Self::filling_fields(&mut join_fields, left_force_nullable, &batch); + } + + for (row, hash) in rows_hashes.iter().enumerate() { + left_hashmap + .entry(*hash) + .or_insert_with(Vec::new) + .push(row + left_row_offset); + } + + left_row_offset += batch.num_rows(); + left_batches.push(batch); + } + + let left_single_batch = if left_batches.is_empty() && (ty == JoinType::Left || ty == JoinType::Inner) { + None + } else { + Some(compute::concat_batches(&left_batches[0].schema(), &left_batches)?) + }; + + // probe phase + // + // build visited_left_side to record the left data has been visited, + // because probe phase only visit the right data, so if we use left-join or full-join, + // the left unvisited data should be returned to meet the join semantics. + let full_left_side: BTreeSet = (0..left_row_offset as u32) + .into_iter() + .collect(); + let mut visited_left_side = BTreeSet::new(); + let mut join_schema = None; + + #[for_await] + for batch in right_input { + let batch = batch?; + let rows_hashes = Self::hash_columns(&on_right_keys, &hash_random_state, &batch)?; + + // init join_schema + let schema = join_schema.get_or_insert_with(|| { + Self::filling_fields(&mut join_fields, right_force_nullable, &batch); + + Arc::new(Schema::new(mem::take(&mut join_fields))) + }); + + // 1. build left and right indices + let mut left_indices = UInt32Builder::new(); + let mut right_indices = UInt32Builder::new(); + // for sqlrs: Get the hash and find it in the build index + // TODO: For every item on the left and right we check if it matches + // This possibly contains rows with hash collisions, + // So we have to check here whether rows are equal or not + for (row, hash) in rows_hashes.iter().enumerate() { + if let Some(indices) = left_hashmap.get(hash) { + for &i in indices { + left_indices.append_value(i as u32); + right_indices.append_value(row as u32); + } + } else if ty == JoinType::Right || ty == JoinType::Full { + // when no match, add the row with None for the left side + left_indices.append_null(); + right_indices.append_value(row as u32); + } + } + + // 2. build intermediate batch that from left and right all columns + let mut left_indices = left_indices.finish(); + let mut right_indices = right_indices.finish(); + + let mut intermediate_batch = Self::build_batch( + &left_single_batch, + &batch, + schema, + &left_indices, + &right_indices + )?; + + if let Some(ref expr) = filter { + let predicate = expr.eval_column(&intermediate_batch)? + .as_any() + .downcast_ref::() + .cloned() + .expect("join filter expected evaluate boolean array"); + left_indices = PrimitiveArray::::from( + compute::filter(&left_indices, &predicate)?.data().clone(), + ); + if ty == JoinType::Right || ty == JoinType::Full { + right_indices = PrimitiveArray::::from( + compute::filter(&right_indices, &predicate)?.data().clone(), + ) + }; + + intermediate_batch = Self::build_batch( + &left_single_batch, + &batch, + schema, + &left_indices, + &right_indices + )?; + } + + if ty == JoinType::Left || ty == JoinType::Full { + left_indices + .iter() + .flatten() + .for_each(|i| { + let _ = visited_left_side.insert(i); + }); + } + yield intermediate_batch; + } + + if let Some(left_batch) = left_single_batch { + if ty == JoinType::Left || ty == JoinType::Full { + let join_schema = join_schema.unwrap(); + + let indices: PrimitiveArray = PrimitiveArray::from_iter_values( + full_left_side + .symmetric_difference(&visited_left_side) + .cloned() + ); + + let mut arrays: Vec = left_batch + .columns() + .iter() + .map(|col| compute::take(col, &indices, None)) + .try_collect()?; + let offset = arrays.len(); + for field in join_schema.fields()[offset..].iter() { + arrays.push(new_null_array(field.data_type(), indices.len())); + } + + yield RecordBatch::try_new(join_schema, arrays)?; + } + } + } + + fn filling_fields(join_fields: &mut Vec, force_nullable: bool, batch: &RecordBatch) { + let mut fields = batch.schema().fields() + .into_iter() + .map(|field| field.clone().with_nullable(force_nullable)) + .collect_vec(); + + join_fields.append(&mut fields); + } + + fn build_batch( + left_single_batch: &Option, + right_batch: &RecordBatch, + schema: &mut Arc, + left_indices: &PrimitiveArray, + right_indices: &PrimitiveArray + ) -> Result { + let full_arrays = if let Some(left_batch) = left_single_batch { + Self::select_with_indices(left_batch, &left_indices)? + } else { + vec![] + }.into_iter() + .chain(Self::select_with_indices(right_batch, &right_indices)?) + .collect_vec(); + Ok(RecordBatch::try_new(schema.clone(), full_arrays)?) + } + + fn select_with_indices(batch: &RecordBatch, indices: &PrimitiveArray) -> Result, ExecutorError> { + Ok(batch + .columns() + .iter() + .map(|col| compute::take(col, &indices, None)) + .try_collect()?) + } + + fn hash_columns( + col_keys: &Vec, + hash_random_state: &RandomState, + batch: &RecordBatch + ) -> Result, ExecutorError> { + let arrays: Vec = col_keys + .iter() + .map(|expr| expr.eval_column(&batch)) + .try_collect()?; + + let mut every_rows_hashes = vec![0; batch.num_rows()]; + create_hashes(&arrays, &hash_random_state, &mut every_rows_hashes)?; + + Ok(every_rows_hashes) + } +} \ No newline at end of file diff --git a/src/execution_v1/volcano_executor/insert.rs b/src/execution_v1/volcano_executor/insert.rs index 3e648f89..cf556024 100644 --- a/src/execution_v1/volcano_executor/insert.rs +++ b/src/execution_v1/volcano_executor/insert.rs @@ -23,6 +23,8 @@ impl Insert { let mut arrays = batch.columns().to_vec(); let col_len = arrays[0].len(); + arrays.reverse(); + let full_arrays = table.all_columns() .into_iter() .map(|(_, col_catalog)| { diff --git a/src/execution_v1/volcano_executor/mod.rs b/src/execution_v1/volcano_executor/mod.rs index 08570442..93414bbe 100644 --- a/src/execution_v1/volcano_executor/mod.rs +++ b/src/execution_v1/volcano_executor/mod.rs @@ -6,6 +6,7 @@ mod values; mod filter; mod sort; mod limit; +mod hash_join; use crate::execution_v1::physical_plan::physical_projection::PhysicalProjection; use crate::execution_v1::physical_plan::PhysicalPlan; @@ -18,14 +19,17 @@ use arrow::record_batch::RecordBatch; use futures::stream::BoxStream; use futures::TryStreamExt; use crate::execution_v1::physical_plan::physical_filter::PhysicalFilter; +use crate::execution_v1::physical_plan::physical_hash_join::PhysicalHashJoin; use crate::execution_v1::physical_plan::physical_insert::PhysicalInsert; use crate::execution_v1::physical_plan::physical_limit::PhysicalLimit; use crate::execution_v1::physical_plan::physical_sort::PhysicalSort; use crate::execution_v1::volcano_executor::filter::Filter; +use crate::execution_v1::volcano_executor::hash_join::HashJoin; use crate::execution_v1::volcano_executor::insert::Insert; use crate::execution_v1::volcano_executor::limit::Limit; use crate::execution_v1::volcano_executor::sort::Sort; use crate::execution_v1::volcano_executor::values::Values; +use crate::planner::operator::join::JoinOperator; pub type BoxedExecutor = BoxStream<'static, Result>; @@ -72,11 +76,19 @@ impl VolcanoExecutor { Sort::execute(op.sort_fields, op.limit, input) } - PhysicalPlan::Limit(PhysicalLimit {op,input, ..}) =>{ + PhysicalPlan::Limit(PhysicalLimit {op,input, ..}) => { let input = self.build(*input); Limit::execute(Some(op.offset), Some(op.limit), input) } + PhysicalPlan::HashJoin(PhysicalHashJoin { op, left_input, right_input}) => { + let left_input = self.build(*left_input); + let right_input = self.build(*right_input); + + let JoinOperator { on, join_type } = op; + + HashJoin::execute(on, join_type, left_input, right_input) + } } } diff --git a/src/execution_v1/volcano_executor/projection.rs b/src/execution_v1/volcano_executor/projection.rs index aac6fb7b..a2256387 100644 --- a/src/execution_v1/volcano_executor/projection.rs +++ b/src/execution_v1/volcano_executor/projection.rs @@ -10,19 +10,24 @@ pub struct Projection { } impl Projection { #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] pub async fn execute(exprs: Vec, input: BoxedExecutor) { + // FIXME: 支持JOIN投射 + // #[for_await] + // for batch in input { + // let batch = batch?; + // let columns = exprs + // .iter() + // .map(|e| e.eval_column(&batch)) + // .try_collect(); + // let fields = exprs.iter().map(|e| e.eval_field(&batch)).collect(); + // let schema = SchemaRef::new(Schema::new_with_metadata( + // fields, + // batch.schema().metadata().clone(), + // )); + // yield RecordBatch::try_new(schema, columns?)?; + // } #[for_await] for batch in input { - let batch = batch?; - let columns = exprs - .iter() - .map(|e| e.eval_column(&batch)) - .try_collect(); - let fields = exprs.iter().map(|e| e.eval_field(&batch)).collect(); - let schema = SchemaRef::new(Schema::new_with_metadata( - fields, - batch.schema().metadata().clone(), - )); - yield RecordBatch::try_new(schema, columns?)?; + yield batch?; } } } \ No newline at end of file diff --git a/src/planner/operator/join.rs b/src/planner/operator/join.rs index c1c9e8ac..01a8ca8b 100644 --- a/src/planner/operator/join.rs +++ b/src/planner/operator/join.rs @@ -6,19 +6,25 @@ use super::Operator; #[derive(Debug, PartialEq, Clone)] pub enum JoinType { Inner, - LeftOuter, - RightOuter, - FullOuter, + Left, + Right, + Full, Cross, - LeftSemi, - RightSemi, - LeftAnti, - RightAnti, +} +#[derive(Debug, Clone, PartialEq)] +pub enum JoinCondition { + On { + /// Equijoin clause expressed as pairs of (left, right) join columns + on: Vec<(ScalarExpression, ScalarExpression)>, + /// Filters applied during join (non-equi conditions) + filter: Option, + }, + None, } #[derive(Debug, PartialEq, Clone)] pub struct JoinOperator { - pub on: Option, + pub on: JoinCondition, pub join_type: JoinType, } @@ -26,7 +32,7 @@ impl JoinOperator { pub fn new( left: LogicalPlan, right: LogicalPlan, - on: Option, + on: JoinCondition, join_type: JoinType, ) -> LogicalPlan { LogicalPlan { diff --git a/src/util/hash_utils.rs b/src/util/hash_utils.rs new file mode 100644 index 00000000..ce7dffae --- /dev/null +++ b/src/util/hash_utils.rs @@ -0,0 +1,250 @@ +// copied from datafusion and deleted unused functions + +use ahash::RandomState; +use arrow::array::{ + Array, ArrayRef, BooleanArray, Float64Array, Int32Array, Int64Array, StringArray, +}; +use arrow::datatypes::DataType; + +use crate::execution_v1::ExecutorError; + +// Combines two hashes into one hash +#[inline] +fn combine_hashes(l: u64, r: u64) -> u64 { + let hash = (17 * 37u64).wrapping_add(l); + hash.wrapping_mul(37).wrapping_add(r) +} + +fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) { + if mul_col { + hashes_buffer.iter_mut().for_each(|hash| { + // stable hash for null value + *hash = combine_hashes(random_state.hash_one(&1), *hash); + }) + } else { + hashes_buffer.iter_mut().for_each(|hash| { + *hash = random_state.hash_one(&1); + }) + } +} + +macro_rules! hash_array { + ( + $array_type:ident, + $column:ident, + $ty:ty, + $hashes:ident, + $random_state:ident, + $multi_col:ident + ) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + if array.null_count() == 0 { + if $multi_col { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = combine_hashes($random_state.hash_one(&array.value(i)), *hash); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $random_state.hash_one(&array.value(i)); + } + } + } else { + if $multi_col { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = combine_hashes($random_state.hash_one(&array.value(i)), *hash); + } + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = $random_state.hash_one(&array.value(i)); + } + } + } + } + }; +} + +macro_rules! hash_array_primitive { + ( + $array_type:ident, + $column:ident, + $ty:ident, + $hashes:ident, + $random_state:ident, + $multi_col:ident + ) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + let values = array.values(); + + if array.null_count() == 0 { + if $multi_col { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + *hash = combine_hashes($random_state.hash_one(value), *hash); + } + } else { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + *hash = $random_state.hash_one(value) + } + } + } else { + if $multi_col { + for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { + if !array.is_null(i) { + *hash = combine_hashes($random_state.hash_one(value), *hash); + } + } + } else { + for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { + if !array.is_null(i) { + *hash = $random_state.hash_one(value); + } + } + } + } + }; +} + +macro_rules! hash_array_float { + ( + $array_type:ident, + $column:ident, + $ty:ident, + $hashes:ident, + $random_state:ident, + $multi_col:ident + ) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + let values = array.values(); + + if array.null_count() == 0 { + if $multi_col { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + *hash = combine_hashes( + $random_state.hash_one(&$ty::from_le_bytes(value.to_le_bytes())), + *hash, + ); + } + } else { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + *hash = $random_state.hash_one(&$ty::from_le_bytes(value.to_le_bytes())) + } + } + } else { + if $multi_col { + for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { + if !array.is_null(i) { + *hash = combine_hashes( + $random_state.hash_one(&$ty::from_le_bytes(value.to_le_bytes())), + *hash, + ); + } + } + } else { + for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { + if !array.is_null(i) { + *hash = $random_state.hash_one(&$ty::from_le_bytes(value.to_le_bytes())); + } + } + } + } + }; +} + +/// Creates hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. +/// `hashes_buffer` should be pre-sized appropriately +#[cfg(not(feature = "force_hash_collisions"))] +pub fn create_hashes<'a>( + arrays: &[ArrayRef], + random_state: &RandomState, + hashes_buffer: &'a mut Vec, +) -> Result<&'a mut Vec, ExecutorError> { + // combine hashes with `combine_hashes` if we have more than 1 column + let multi_col = arrays.len() > 1; + + for col in arrays { + match col.data_type() { + DataType::Null => { + hash_null(random_state, hashes_buffer, multi_col); + } + DataType::Int32 => { + hash_array_primitive!(Int32Array, col, i32, hashes_buffer, random_state, multi_col); + } + DataType::Int64 => { + hash_array_primitive!(Int64Array, col, i64, hashes_buffer, random_state, multi_col); + } + DataType::Float64 => { + hash_array_float!( + Float64Array, + col, + u64, + hashes_buffer, + random_state, + multi_col + ); + } + DataType::Boolean => { + hash_array!( + BooleanArray, + col, + u8, + hashes_buffer, + random_state, + multi_col + ); + } + DataType::Utf8 => { + hash_array!( + StringArray, + col, + str, + hashes_buffer, + random_state, + multi_col + ); + } + _ => { + // This is internal because we should have caught this before. + return Err(ExecutorError::InternalError(format!( + "Unsupported data type in hasher: {}", + col.data_type() + ))); + } + } + } + Ok(hashes_buffer) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + + #[test] + fn create_hashes_for_float_arrays() -> Result<(), ExecutorError> { + let f64_arr = Arc::new(Float64Array::from_iter_values(vec![0.12, 0.5, 1f64, 444.7])); + let f64_arr_2 = Arc::new(Float64Array::from_iter_values(vec![0.12, 0.5, 1f64, 444.7])); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let hashes_buff = &mut vec![0; f64_arr.len()]; + + let hashes = create_hashes(&[f64_arr, f64_arr_2], &random_state, hashes_buff)?; + assert_eq!(hashes.len(), 4); + assert_eq!(hashes.clone(), hashes_buff.clone()); + assert_eq!( + hashes_buff, + &[ + 13192744372685867462, + 5527281222425499956, + 3851526787237496334, + 1092489821776418240, + ] + ); + Ok(()) + } +} diff --git a/src/util/mod.rs b/src/util/mod.rs index 698484a0..adef7d27 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,6 +1,7 @@ +pub mod hash_utils; use arrow::record_batch::RecordBatch; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::DataType; use arrow::error::ArrowError; use arrow::util::display::array_value_to_string; From fb7fffea12dda34aa8c8570d4acfda81ff6c8094 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Sat, 5 Aug 2023 14:39:49 +0800 Subject: [PATCH 2/7] fix(hash_join): fix filter on right join --- src/db.rs | 9 +- .../volcano_executor/hash_join.rs | 124 +++++++++++------- 2 files changed, 82 insertions(+), 51 deletions(-) diff --git a/src/db.rs b/src/db.rs index 34bd9641..e43e0dc1 100644 --- a/src/db.rs +++ b/src/db.rs @@ -154,6 +154,7 @@ mod test { let _ = kipsql.run("insert into t1 values (1, 1), (3, 3), (5, 4)").await?; let _ = kipsql.run("insert into t2 values (1, 2), (2, 3), (5, 6)").await?; + println!("full:"); let vec_batch_full_fields = kipsql.run("select * from t1 right join t2 on a = c").await?; print_batches(&vec_batch_full_fields)?; @@ -170,19 +171,19 @@ mod test { print_batches(&vec_batch_limit)?; println!("inner join:"); - let vec_batch_inner_join = kipsql.run("select * from t1 inner join t2 on a = c").await?; + let vec_batch_inner_join = kipsql.run("select * from t1 inner join t2 on a = c and c > 1").await?; print_batches(&vec_batch_inner_join)?; println!("left join:"); - let vec_batch_left_join = kipsql.run("select * from t1 left join t2 on a = c").await?; + let vec_batch_left_join = kipsql.run("select * from t1 left join t2 on a = c and c > 1").await?; print_batches(&vec_batch_left_join)?; println!("right join:"); - let vec_batch_right_join = kipsql.run("select * from t1 right join t2 on a = c").await?; + let vec_batch_right_join = kipsql.run("select * from t1 right join t2 on a = c and c > 1").await?; print_batches(&vec_batch_right_join)?; println!("full join:"); - let vec_batch_full_join = kipsql.run("select * from t1 full join t2 on a = c").await?; + let vec_batch_full_join = kipsql.run("select * from t1 full join t2 on a = c and c > 1").await?; print_batches(&vec_batch_full_join)?; Ok(()) diff --git a/src/execution_v1/volcano_executor/hash_join.rs b/src/execution_v1/volcano_executor/hash_join.rs index 2a839536..d3cd7a70 100644 --- a/src/execution_v1/volcano_executor/hash_join.rs +++ b/src/execution_v1/volcano_executor/hash_join.rs @@ -65,10 +65,10 @@ impl HashJoin { left_batches.push(batch); } - let left_single_batch = if left_batches.is_empty() && (ty == JoinType::Left || ty == JoinType::Inner) { - None - } else { + let left_single_batch = if !left_batches.is_empty() { Some(compute::concat_batches(&left_batches[0].schema(), &left_batches)?) + } else { + None }; // probe phase @@ -88,11 +88,12 @@ impl HashJoin { let rows_hashes = Self::hash_columns(&on_right_keys, &hash_random_state, &batch)?; // init join_schema - let schema = join_schema.get_or_insert_with(|| { - Self::filling_fields(&mut join_fields, right_force_nullable, &batch); - - Arc::new(Schema::new(mem::take(&mut join_fields))) - }); + let schema = Self::init_schema( + &mut join_fields, + &mut join_schema, + right_force_nullable, + &batch + ); // 1. build left and right indices let mut left_indices = UInt32Builder::new(); @@ -116,7 +117,7 @@ impl HashJoin { // 2. build intermediate batch that from left and right all columns let mut left_indices = left_indices.finish(); - let mut right_indices = right_indices.finish(); + let mut right_indices = right_indices.finish(); let mut intermediate_batch = Self::build_batch( &left_single_batch, @@ -127,27 +128,36 @@ impl HashJoin { )?; if let Some(ref expr) = filter { - let predicate = expr.eval_column(&intermediate_batch)? - .as_any() - .downcast_ref::() - .cloned() - .expect("join filter expected evaluate boolean array"); - left_indices = PrimitiveArray::::from( - compute::filter(&left_indices, &predicate)?.data().clone(), - ); - if ty == JoinType::Right || ty == JoinType::Full { - right_indices = PrimitiveArray::::from( - compute::filter(&right_indices, &predicate)?.data().clone(), - ) - }; - - intermediate_batch = Self::build_batch( - &left_single_batch, - &batch, - schema, - &left_indices, - &right_indices - )?; + if !(ty == JoinType::Full || ty == JoinType::Cross) { + let predicate = expr.eval_column(&intermediate_batch)? + .as_any() + .downcast_ref::() + .cloned() + .expect("join filter expected evaluate boolean array"); + left_indices = PrimitiveArray::::from( + compute::filter(&left_indices, &predicate)?.data().clone(), + ); + if ty == JoinType::Right { + let abs = left_indices.len().abs_diff(right_indices.len()); + if abs > 0 { + left_indices = left_indices.into_iter() + .chain((0..abs).map(|_| None)) + .collect(); + } + } else { + right_indices = PrimitiveArray::::from( + compute::filter(&right_indices, &predicate)?.data().clone(), + ); + } + + intermediate_batch = Self::build_batch( + &left_single_batch, + &batch, + schema, + &left_indices, + &right_indices + )?; + } } if ty == JoinType::Left || ty == JoinType::Full { @@ -162,30 +172,50 @@ impl HashJoin { } if let Some(left_batch) = left_single_batch { - if ty == JoinType::Left || ty == JoinType::Full { - let join_schema = join_schema.unwrap(); + if !(ty == JoinType::Left || ty == JoinType::Full) { + return Ok(()); + } - let indices: PrimitiveArray = PrimitiveArray::from_iter_values( - full_left_side - .symmetric_difference(&visited_left_side) - .cloned() - ); + let schema = Self::init_schema( + &mut join_fields, + &mut join_schema, + left_force_nullable, + &left_batch + ).clone(); - let mut arrays: Vec = left_batch - .columns() - .iter() - .map(|col| compute::take(col, &indices, None)) - .try_collect()?; - let offset = arrays.len(); - for field in join_schema.fields()[offset..].iter() { - arrays.push(new_null_array(field.data_type(), indices.len())); - } + let indices: PrimitiveArray = PrimitiveArray::from_iter_values( + full_left_side + .symmetric_difference(&visited_left_side) + .cloned() + ); - yield RecordBatch::try_new(join_schema, arrays)?; + let mut arrays: Vec = left_batch + .columns() + .iter() + .map(|col| compute::take(col, &indices, None)) + .try_collect()?; + let offset = arrays.len(); + for field in schema.fields()[offset..].iter() { + arrays.push(new_null_array(field.data_type(), indices.len())); } + + yield RecordBatch::try_new(schema, arrays)?; } } + fn init_schema<'a>( + join_fields: &mut Vec, + join_schema: &'a mut Option>, + force_nullable: bool, + batch: &RecordBatch + ) -> &'a mut Arc { + join_schema.get_or_insert_with(|| { + Self::filling_fields(join_fields, force_nullable, &batch); + + Arc::new(Schema::new(mem::take(join_fields))) + }) + } + fn filling_fields(join_fields: &mut Vec, force_nullable: bool, batch: &RecordBatch) { let mut fields = batch.schema().fields() .into_iter() From 9170a368a4355bb43bf0a80012eb769470faca73 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Sun, 6 Aug 2023 01:17:26 +0800 Subject: [PATCH 3/7] feat(projection): `Projection` supports multi-table mapping By marshaling Column and rewriting `ScalarExpression` to convert `ColumnRef` to `InputRef` --- src/binder/create_table.rs | 2 +- src/binder/expr.rs | 2 +- src/binder/mod.rs | 12 +- src/binder/select.rs | 29 ++-- src/catalog/column.rs | 4 +- src/catalog/mod.rs | 28 ---- src/catalog/root.rs | 24 +-- src/catalog/table.rs | 37 ++--- src/db.rs | 18 +-- .../physical_plan/physical_plan_builder.rs | 150 ++++++++++++++++-- src/execution_v1/volcano_executor/insert.rs | 2 +- .../volcano_executor/projection.rs | 27 ++-- src/execution_v1/volcano_executor/sort.rs | 1 + .../volcano_executor/table_scan.rs | 4 +- src/expression/evaluator.rs | 7 +- src/planner/operator/join.rs | 2 +- src/planner/operator/scan.rs | 12 +- src/storage/memory.rs | 17 +- src/storage/mod.rs | 8 +- src/types/mod.rs | 61 +++---- 20 files changed, 266 insertions(+), 181 deletions(-) diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 4ce552e6..381a8b23 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -61,7 +61,7 @@ mod tests { let sql = "create table t1 (id int , name varchar(10))"; let binder = Binder::new(BinderContext::new(RootCatalog::new())); let stmt = crate::parser::parse_sql(sql).unwrap(); - let plan1 = binder.bind(&stmt[0]).unwrap(); + let (plan1, _) = binder.bind(&stmt[0]).unwrap(); let plan2 = LogicalPlan { operator: Operator::CreateTable( diff --git a/src/binder/expr.rs b/src/binder/expr.rs index fe333f95..50cb6c65 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -68,7 +68,7 @@ impl Binder { } else { // handle col syntax let mut got_column = None; - for table_catalog in &self.context.catalog.tables { + for (_, table_catalog) in self.context.catalog.tables() { if let Some(column_catalog) = table_catalog.get_column_by_name(column_name) { if got_column.is_some() { return Err(BindError::InvalidColumn(column_name.to_string()).into()); diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 9b3f7ea4..4c581737 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -12,11 +12,11 @@ use sqlparser::ast::{Ident, ObjectName, SetExpr, Statement}; use crate::catalog::{RootCatalog, DEFAULT_SCHEMA_NAME, CatalogError}; use crate::expression::ScalarExpression; use crate::planner::LogicalPlan; -use crate::types::TableIdx; -#[derive(Clone)] +use crate::types::TableId; +#[derive(Debug, Clone)] pub struct BinderContext { - catalog: RootCatalog, - bind_table: BTreeMap, + pub(crate) catalog: RootCatalog, + pub(crate) bind_table: BTreeMap, aliases: BTreeMap, group_by_exprs: Vec, agg_calls: Vec, @@ -59,7 +59,7 @@ impl Binder { Binder { context } } - pub fn bind(mut self, stmt: &Statement) -> Result { + pub fn bind(mut self, stmt: &Statement) -> Result<(LogicalPlan, BinderContext)> { let plan = match stmt { Statement::Query(query) => self.bind_query(query)?, Statement::CreateTable { name, columns, .. } => self.bind_create_table(name, &columns)?, @@ -72,7 +72,7 @@ impl Binder { } _ => unimplemented!(), }; - Ok(plan) + Ok((plan, self.context)) } } diff --git a/src/binder/select.rs b/src/binder/select.rs index 34677dd1..11c79fb9 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -1,7 +1,6 @@ use std::borrow::Borrow; use crate::{ - catalog::ColumnRefId, expression::ScalarExpression, planner::{ operator::{ @@ -27,7 +26,7 @@ use crate::expression::BinaryOperator; use crate::planner::LogicalPlan; use crate::planner::operator::join::JoinCondition; use crate::planner::operator::sort::{SortField, SortOperator}; -use crate::types::{LogicalType, TableIdx}; +use crate::types::{LogicalType, TableId}; impl Binder { pub(crate) fn bind_query(&mut self, query: &Query) -> Result { @@ -100,6 +99,7 @@ impl Binder { } plan = self.bind_project(plan, select_list); + Ok(plan) } @@ -124,7 +124,7 @@ impl Binder { Ok(plan) } - fn bind_single_table_ref(&mut self, table: &TableFactor) -> Result<(TableIdx, LogicalPlan)> { + fn bind_single_table_ref(&mut self, table: &TableFactor) -> Result<(TableId, LogicalPlan)> { let plan_with_id = match table { TableFactor::Table { name, alias, .. } => { let obj_name = name @@ -203,16 +203,9 @@ impl Binder { fn bind_all_column_refs(&mut self) -> Result> { let mut exprs = vec![]; - for ref_id in self.context.bind_table.values().cloned().collect_vec() { - let table = self.context.catalog.get_table(ref_id).unwrap(); - for (col_id, col) in table.all_columns() { - let column_ref_id = ColumnRefId::from_table(ref_id, col_id); - // self.record_regular_table_column( - // &table.name(), - // col.name(), - // *col_id, - // col.desc().clone(), - // ); + for table_id in self.context.bind_table.values().cloned().collect_vec() { + let table = self.context.catalog.get_table(&table_id).unwrap(); + for (_, col) in table.all_columns() { exprs.push(ScalarExpression::ColumnRef(col.clone())); } } @@ -220,7 +213,7 @@ impl Binder { Ok(exprs) } - fn bind_join(&mut self, left_id: TableIdx, left: LogicalPlan, join: &Join) -> Result { + fn bind_join(&mut self, left_id: TableId, left: LogicalPlan, join: &Join) -> Result { let Join { relation, join_operator, @@ -237,11 +230,11 @@ impl Binder { _ => unimplemented!(), }; let left_table = self.context.catalog - .get_table(left_id) + .get_table(&left_id) .cloned() .expect("Left table not found"); let right_table = self.context.catalog - .get_table(right_id) + .get_table(&right_id) .cloned() .expect("Right table not found"); @@ -487,7 +480,7 @@ mod tests { let binder = Binder::new(BinderContext::new(root)); let stmt = crate::parser::parse_sql(sql).unwrap(); - binder.bind(&stmt[0]) + Ok(binder.bind(&stmt[0])?.0) } #[test] @@ -535,7 +528,7 @@ mod tests { plan_8 ); - let plan_9 = select_sql_run("select * from t1 inner join t2 on c1 = c3 and c1 > 1")?; + let plan_9 = select_sql_run("select c1, c3 from t1 inner join t2 on c1 = c3 and c1 > 1")?; println!( "join:\n {:#?}", plan_9 diff --git a/src/catalog/column.rs b/src/catalog/column.rs index a2b9c149..20dd2ed5 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -1,11 +1,11 @@ use arrow::datatypes::{DataType, Field}; use sqlparser::ast::{ColumnDef, ColumnOption}; -use crate::types::{ColumnIdx, LogicalType}; +use crate::types::{ColumnId, LogicalType}; #[derive(Debug, Clone, PartialEq)] pub struct ColumnCatalog { - pub id: Option, + pub id: Option, pub name: String, pub nullable: bool, pub desc: ColumnDesc, diff --git a/src/catalog/mod.rs b/src/catalog/mod.rs index 7f341085..45e7c5d9 100644 --- a/src/catalog/mod.rs +++ b/src/catalog/mod.rs @@ -4,7 +4,6 @@ use std::sync::Arc; pub(crate) use self::column::*; pub(crate) use self::root::*; pub(crate) use self::table::*; -use crate::types::{ColumnIdx, TableIdx}; /// The type of catalog reference. pub type CatalogRef = Arc; @@ -16,33 +15,6 @@ mod column; mod root; mod table; -/// The reference ID of a column. -#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)] -pub struct ColumnRefId { - pub table_id: TableIdx, - pub column_id: ColumnIdx, -} - -impl ColumnRefId { - pub const fn from_table(table_id: TableIdx, column_id: ColumnIdx) -> Self { - ColumnRefId { - table_id, - column_id, - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)] -pub struct TableRefId { - pub table_id: TableIdx, -} - -impl TableRefId { - pub const fn new(table_id: TableIdx) -> Self { - TableRefId { table_id } - } -} - #[derive(thiserror::Error, Debug)] pub enum CatalogError { #[error("{0} not found: {1}")] diff --git a/src/catalog/root.rs b/src/catalog/root.rs index 896f2a5d..f36d9473 100644 --- a/src/catalog/root.rs +++ b/src/catalog/root.rs @@ -1,13 +1,12 @@ use std::collections::BTreeMap; use crate::catalog::{CatalogError, ColumnCatalog, TableCatalog}; -use crate::types::{IdGenerator, TableIdx}; +use crate::types::{IdGenerator, TableId}; #[derive(Debug, Clone)] pub struct RootCatalog { - generator: IdGenerator, - pub table_idxs: BTreeMap, - pub tables: Vec, + table_idxs: BTreeMap, + tables: BTreeMap, } impl Default for RootCatalog { @@ -20,35 +19,34 @@ impl RootCatalog { #[allow(dead_code)] pub fn new() -> Self { RootCatalog { - generator: IdGenerator::new(), table_idxs: Default::default(), tables: Default::default(), } } - pub(crate) fn get_table_id_by_name(&self, name: &str) -> Option { + pub(crate) fn get_table_id_by_name(&self, name: &str) -> Option { self.table_idxs.get(name).cloned() } - pub(crate) fn get_table(&self, table_id: TableIdx) -> Option<&TableCatalog> { + pub(crate) fn get_table(&self, table_id: &TableId) -> Option<&TableCatalog> { self.tables.get(table_id) } pub(crate) fn get_table_by_name(&self, name: &str) -> Option<&TableCatalog> { let id = self.table_idxs.get(name)?; - self.tables.get(*id) + self.tables.get(id) } pub(crate) fn add_table( &mut self, table_name: String, columns: Vec, - ) -> Result { + ) -> Result { if self.table_idxs.contains_key(&table_name) { return Err(CatalogError::Duplicated("column", table_name)); } let mut table = TableCatalog::new(table_name.to_owned(), columns)?; - let table_id = self.generator.build(); + let table_id = IdGenerator::build(); table.id = Some(table_id); self.table_idxs.insert(table_name, table_id); @@ -56,6 +54,12 @@ impl RootCatalog { Ok(table_id) } + + pub(crate) fn tables(&self) -> Vec<(&TableId, &TableCatalog)> { + self.tables + .iter() + .collect() + } } #[cfg(test)] diff --git a/src/catalog/table.rs b/src/catalog/table.rs index 5024208a..1157922a 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -1,19 +1,18 @@ -use std::collections::HashMap; +use std::collections::BTreeMap; use std::sync::Arc; use arrow::datatypes::{Schema, SchemaRef}; use itertools::Itertools; use crate::catalog::{CatalogError, ColumnCatalog}; -use crate::types::{ColumnIdx, IdGenerator, TableIdx}; +use crate::types::{ColumnId, IdGenerator, TableId}; #[derive(Debug, Clone, PartialEq)] pub struct TableCatalog { - pub id: Option, + pub id: Option, pub name: String, - generator: IdGenerator, /// Mapping from column names to column ids - column_idxs: HashMap, - pub(crate) columns: Vec, + column_idxs: BTreeMap, + columns: BTreeMap, } impl TableCatalog { @@ -21,34 +20,33 @@ impl TableCatalog { self.columns.len() } - pub(crate) fn get_column_by_id(&self, id: ColumnIdx) -> Option<&ColumnCatalog> { + pub(crate) fn get_column_by_id(&self, id: &ColumnId) -> Option<&ColumnCatalog> { self.columns.get(id) } - pub(crate) fn get_column_id_by_name(&self, name: &str) -> Option { + pub(crate) fn get_column_id_by_name(&self, name: &str) -> Option { self.column_idxs.get(name).cloned() } pub(crate) fn get_column_by_name(&self, name: &str) -> Option<&ColumnCatalog> { let id = self.column_idxs.get(name)?; - self.columns.get(*id) + self.columns.get(id) } pub(crate) fn contains_column(&self, name: &str) -> bool { self.column_idxs.contains_key(name) } - pub(crate) fn all_columns(&self) -> Vec<(ColumnIdx, &ColumnCatalog)> { + pub(crate) fn all_columns(&self) -> Vec<(&ColumnId, &ColumnCatalog)> { self.columns .iter() - .enumerate() - .collect_vec() + .collect() } // TODO: 缓存schema pub(crate) fn schema(&self) -> SchemaRef { let fields = self.columns.iter() - .map(ColumnCatalog::to_field) + .map(|(_, col)| col.to_field()) .collect_vec(); Arc::new(Schema::new(fields)) } @@ -57,12 +55,12 @@ impl TableCatalog { pub(crate) fn add_column( &mut self, mut col_catalog: ColumnCatalog, - ) -> Result { + ) -> Result { if self.column_idxs.contains_key(&col_catalog.name) { return Err(CatalogError::Duplicated("column", col_catalog.name.into())); } - let col_id = self.generator.build(); + let col_id = IdGenerator::build(); col_catalog.id = Some(col_id); self.column_idxs.insert(col_catalog.name.to_owned(), col_id); @@ -78,9 +76,8 @@ impl TableCatalog { let mut table_catalog = TableCatalog { id: None, name: table_name, - generator: IdGenerator::new(), - column_idxs: HashMap::new(), - columns: Vec::new(), + column_idxs: BTreeMap::new(), + columns: BTreeMap::new(), }; for col_catalog in columns.into_iter() { @@ -116,11 +113,11 @@ mod tests { let col_b_id = table_catalog.get_column_id_by_name("b").unwrap(); assert!(col_a_id < col_b_id); - let column_catalog = table_catalog.get_column_by_id(col_a_id).unwrap(); + let column_catalog = table_catalog.get_column_by_id(&col_a_id).unwrap(); assert_eq!(column_catalog.name, "a"); assert_eq!(*column_catalog.datatype(), LogicalType::Integer,); - let column_catalog = table_catalog.get_column_by_id(col_b_id).unwrap(); + let column_catalog = table_catalog.get_column_by_id(&col_b_id).unwrap(); assert_eq!(column_catalog.name, "b"); assert_eq!(*column_catalog.datatype(), LogicalType::Boolean,); } diff --git a/src/db.rs b/src/db.rs index e43e0dc1..1f2211b0 100644 --- a/src/db.rs +++ b/src/db.rs @@ -44,10 +44,10 @@ impl Database { /// Sort(a) /// Limit(1) /// Project(a,b) - let logical_plan = binder.bind(&stmts[0])?; + let (logical_plan, bind_context) = binder.bind(&stmts[0])?; // println!("logic plan: {:#?}", logical_plan); - let mut builder = PhysicalPlanBuilder::new(); + let mut builder = PhysicalPlanBuilder::new(bind_context); let operator = builder.build_plan(&logical_plan)?; // println!("operator: {:#?}", operator); @@ -102,9 +102,9 @@ mod test { use crate::catalog::{ColumnCatalog, ColumnDesc}; use crate::db::Database; use crate::storage::{Storage, StorageError}; - use crate::types::{LogicalType, TableIdx}; + use crate::types::{LogicalType, TableId}; - fn build_table(storage: &impl Storage) -> Result { + fn build_table(storage: &impl Storage) -> Result { let schema = Arc::new(Schema::new( vec![ ColumnCatalog::new( @@ -155,7 +155,7 @@ mod test { let _ = kipsql.run("insert into t2 values (1, 2), (2, 3), (5, 6)").await?; println!("full:"); - let vec_batch_full_fields = kipsql.run("select * from t1 right join t2 on a = c").await?; + let vec_batch_full_fields = kipsql.run("select * from t1").await?; print_batches(&vec_batch_full_fields)?; println!("projection_and_filter:"); @@ -171,19 +171,19 @@ mod test { print_batches(&vec_batch_limit)?; println!("inner join:"); - let vec_batch_inner_join = kipsql.run("select * from t1 inner join t2 on a = c and c > 1").await?; + let vec_batch_inner_join = kipsql.run("select * from t1 inner join t2 on a = c").await?; print_batches(&vec_batch_inner_join)?; println!("left join:"); - let vec_batch_left_join = kipsql.run("select * from t1 left join t2 on a = c and c > 1").await?; + let vec_batch_left_join = kipsql.run("select * from t1 left join t2 on a = c").await?; print_batches(&vec_batch_left_join)?; println!("right join:"); - let vec_batch_right_join = kipsql.run("select * from t1 right join t2 on a = c and c > 1").await?; + let vec_batch_right_join = kipsql.run("select * from t1 right join t2 on a = c and a > 1").await?; print_batches(&vec_batch_right_join)?; println!("full join:"); - let vec_batch_full_join = kipsql.run("select * from t1 full join t2 on a = c and c > 1").await?; + let vec_batch_full_join = kipsql.run("select d, b from t1 full join t2 on a = c and a > 1").await?; print_batches(&vec_batch_full_join)?; Ok(()) diff --git a/src/execution_v1/physical_plan/physical_plan_builder.rs b/src/execution_v1/physical_plan/physical_plan_builder.rs index b780b201..80d99347 100644 --- a/src/execution_v1/physical_plan/physical_plan_builder.rs +++ b/src/execution_v1/physical_plan/physical_plan_builder.rs @@ -1,3 +1,5 @@ +use std::mem; +use ahash::HashMap; use crate::execution_v1::physical_plan::physical_create_table::PhysicalCreateTable; use crate::execution_v1::physical_plan::physical_projection::PhysicalProjection; use crate::execution_v1::physical_plan::physical_table_scan::PhysicalTableScan; @@ -7,31 +9,56 @@ use crate::planner::operator::Operator; use crate::planner::LogicalPlan; use anyhow::anyhow; use anyhow::Result; +use itertools::Itertools; +use crate::binder::BinderContext; use crate::execution_v1::physical_plan::physical_filter::PhysicalFilter; use crate::execution_v1::physical_plan::physical_hash_join::PhysicalHashJoin; use crate::execution_v1::physical_plan::physical_insert::PhysicalInsert; use crate::execution_v1::physical_plan::physical_limit::PhysicalLimit; use crate::execution_v1::physical_plan::physical_sort::PhysicalSort; use crate::execution_v1::physical_plan::physical_values::PhysicalValues; +use crate::expression::ScalarExpression; use crate::planner::operator::create_table::CreateTableOperator; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::insert::InsertOperator; -use crate::planner::operator::join::{JoinOperator, JoinType}; +use crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType}; use crate::planner::operator::limit::LimitOperator; use crate::planner::operator::project::ProjectOperator; -use crate::planner::operator::sort::SortOperator; +use crate::planner::operator::sort::{SortField, SortOperator}; use crate::planner::operator::values::ValuesOperator; +use crate::types::ColumnId; -pub struct PhysicalPlanBuilder { } +pub struct PhysicalPlanBuilder { + column_index: HashMap +} impl PhysicalPlanBuilder { - pub fn new() -> Self { - PhysicalPlanBuilder { } + pub fn new(context: BinderContext) -> Self { + let mut pos = 0usize; + let root = &context.catalog; + let column_index = context.bind_table + .iter() + .filter_map(|(_, table_id)| { + root.get_table(table_id) + .map(|table| { + table.all_columns() + .iter() + .map(|(col_id, _)| { + let next_pos = pos + 1; + (**col_id, mem::replace(&mut pos, next_pos)) + }) + .collect_vec() + }) + }) + .flatten() + .collect(); + + PhysicalPlanBuilder { column_index } } pub fn build_plan(&mut self, plan: &LogicalPlan) -> Result { match &plan.operator { - Operator::Project(op) => self.build_physical_select_projection(plan, op), + Operator::Project(op) => self.build_physical_projection(plan, op), Operator::Scan(scan) => Ok(self.build_physical_scan(scan.clone())), Operator::Filter(op) => self.build_physical_filter(plan, op), Operator::CreateTable(op) => Ok(self.build_physical_create_table(op)), @@ -71,11 +98,16 @@ impl PhysicalPlanBuilder { ) } - fn build_physical_select_projection(&mut self, plan: &LogicalPlan, op: &ProjectOperator) -> Result { + fn build_physical_projection(&mut self, plan: &LogicalPlan, op: &ProjectOperator) -> Result { let input = self.build_plan(plan.child(0)?)?; + let exprs = op.columns + .iter() + .map(|expr| self.rewriter_expr(expr)) + .collect_vec(); + Ok(PhysicalPlan::Projection(PhysicalProjection { - exprs: op.columns.clone(), + exprs, input: Box::new(input), })) } @@ -88,16 +120,30 @@ impl PhysicalPlanBuilder { let input = self.build_plan(plan.child(0)?)?; Ok(PhysicalPlan::Filter(PhysicalFilter { - predicate: base.predicate.clone(), + predicate: self.rewriter_expr(&base.predicate), input: Box::new(input), })) } - fn build_physical_sort(&mut self, plan: &LogicalPlan, base: &SortOperator) -> Result { + fn build_physical_sort(&mut self, plan: &LogicalPlan, SortOperator { sort_fields, limit }: &SortOperator) -> Result { let input = self.build_plan(plan.child(0)?)?; + let rewrite_sort_fields = sort_fields + .into_iter() + .map(|SortField{ expr, desc, nulls_first }| { + SortField { + expr: self.rewriter_expr(expr), + desc: desc.clone(), + nulls_first: nulls_first.clone(), + } + }) + .collect_vec(); + Ok(PhysicalPlan::Sort(PhysicalSort { - op: base.clone(), + op: SortOperator { + sort_fields: rewrite_sort_fields, + limit: limit.clone(), + }, input: Box::new(input), })) } @@ -111,18 +157,94 @@ impl PhysicalPlanBuilder { })) } - fn build_physical_join(&mut self, plan: &LogicalPlan, base: &JoinOperator) -> Result { + fn build_physical_join(&mut self, plan: &LogicalPlan, JoinOperator{ on, join_type } : &JoinOperator) -> Result { let left_input = Box::new(self.build_plan(plan.child(0)?)?); let right_input = Box::new(self.build_plan(plan.child(1)?)?); - if base.join_type == JoinType::Cross { + let on = if let JoinCondition::On { on, filter } = on { + let rewrite_on = on.iter() + .map(|(left_expr, right_expr)| { + (self.rewriter_expr(left_expr), self.rewriter_expr(right_expr)) + }) + .collect_vec(); + let filter = filter + .as_ref() + .map(|expr| self.rewriter_expr(expr)); + + JoinCondition::On { on: rewrite_on, filter } + } else { + JoinCondition::None + }; + + if join_type == &JoinType::Cross { todo!() } else { Ok(PhysicalPlan::HashJoin(PhysicalHashJoin { - op: base.clone(), + op: JoinOperator { + on, + join_type: join_type.clone(), + }, left_input, right_input, })) } } + + fn rewriter_expr(&mut self, expr: &ScalarExpression) -> ScalarExpression { + match expr { + ScalarExpression::ColumnRef(col) => { + ScalarExpression::InputRef { + // FIXME: remove unwrap + index: *self.column_index.get(&col.id.unwrap()).unwrap(), + ty: col.datatype().clone(), + } + } + ScalarExpression::Alias { expr, alias } => { + ScalarExpression::Alias { + expr: Box::new(self.rewriter_expr(expr)), + alias: alias.clone() + } + } + ScalarExpression::TypeCast { expr, ty, is_try } => { + ScalarExpression::TypeCast { + expr: Box::new(self.rewriter_expr(expr)), + ty: ty.clone(), + is_try: is_try.clone(), + } + } + ScalarExpression::IsNull { expr } => { + ScalarExpression::IsNull { + expr: Box::new(self.rewriter_expr(expr)) + } + } + ScalarExpression::Unary { op, expr, ty } => { + ScalarExpression::Unary { + op: op.clone(), + expr: Box::new(self.rewriter_expr(expr)), + ty: ty.clone(), + } + } + ScalarExpression::Binary { op, left_expr, right_expr, ty } => { + ScalarExpression::Binary { + op: op.clone(), + left_expr: Box::new(self.rewriter_expr(left_expr)), + right_expr: Box::new(self.rewriter_expr(right_expr)), + ty: ty.clone(), + } + } + ScalarExpression::AggCall { kind, args, ty } => { + let rewrite_args = args + .into_iter() + .map(|expr| self.rewriter_expr(expr)) + .collect_vec(); + + ScalarExpression::AggCall { + kind: kind.clone(), + args: rewrite_args, + ty: ty.clone(), + } + } + _ => expr.clone() + } + } } \ No newline at end of file diff --git a/src/execution_v1/volcano_executor/insert.rs b/src/execution_v1/volcano_executor/insert.rs index cf556024..7c76cda7 100644 --- a/src/execution_v1/volcano_executor/insert.rs +++ b/src/execution_v1/volcano_executor/insert.rs @@ -38,7 +38,7 @@ impl Insert { let new_batch = RecordBatch::try_new(table.schema(), full_arrays)?; - storage.get_table(table.id.unwrap())?.append(new_batch)?; + storage.get_table(&table.id.unwrap())?.append(new_batch)?; } } else { Err(CatalogError::NotFound("root", table_name.to_string()))?; diff --git a/src/execution_v1/volcano_executor/projection.rs b/src/execution_v1/volcano_executor/projection.rs index a2256387..aac6fb7b 100644 --- a/src/execution_v1/volcano_executor/projection.rs +++ b/src/execution_v1/volcano_executor/projection.rs @@ -10,24 +10,19 @@ pub struct Projection { } impl Projection { #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] pub async fn execute(exprs: Vec, input: BoxedExecutor) { - // FIXME: 支持JOIN投射 - // #[for_await] - // for batch in input { - // let batch = batch?; - // let columns = exprs - // .iter() - // .map(|e| e.eval_column(&batch)) - // .try_collect(); - // let fields = exprs.iter().map(|e| e.eval_field(&batch)).collect(); - // let schema = SchemaRef::new(Schema::new_with_metadata( - // fields, - // batch.schema().metadata().clone(), - // )); - // yield RecordBatch::try_new(schema, columns?)?; - // } #[for_await] for batch in input { - yield batch?; + let batch = batch?; + let columns = exprs + .iter() + .map(|e| e.eval_column(&batch)) + .try_collect(); + let fields = exprs.iter().map(|e| e.eval_field(&batch)).collect(); + let schema = SchemaRef::new(Schema::new_with_metadata( + fields, + batch.schema().metadata().clone(), + )); + yield RecordBatch::try_new(schema, columns?)?; } } } \ No newline at end of file diff --git a/src/execution_v1/volcano_executor/sort.rs b/src/execution_v1/volcano_executor/sort.rs index 96f9b853..e6e089ee 100644 --- a/src/execution_v1/volcano_executor/sort.rs +++ b/src/execution_v1/volcano_executor/sort.rs @@ -29,6 +29,7 @@ impl Sort { .into_iter() .map(|SortField { expr, desc, nulls_first }| -> Result { let values = expr.eval_column(&batch)?; + Ok(SortColumn { values, options: Some(SortOptions { diff --git a/src/execution_v1/volcano_executor/table_scan.rs b/src/execution_v1/volcano_executor/table_scan.rs index 49e14904..bee2533c 100644 --- a/src/execution_v1/volcano_executor/table_scan.rs +++ b/src/execution_v1/volcano_executor/table_scan.rs @@ -12,9 +12,9 @@ impl TableScan { #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] pub async fn execute(plan: PhysicalTableScan, storage: impl Storage) { // TODO: sort_fields, pre_where, limit - let ScanOperator { table_ref_id, .. } = plan.base; + let ScanOperator { table_id, .. } = plan.base; - let table = storage.get_table(table_ref_id)?; + let table = storage.get_table(&table_id)?; let mut transaction = table.read( None, None diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 9cda6723..81f31160 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -12,12 +12,9 @@ impl ScalarExpression { match &self { ScalarExpression::Constant(val) => Ok(val.to_array_of_size(batch.num_rows())), - ScalarExpression::ColumnRef(col) => { - let index = col.id.expect("The Column does not belong to the Table"); - Ok(batch.column(index).clone()) - } + ScalarExpression::ColumnRef(_) => unreachable!("column ref should be resolved"), ScalarExpression::InputRef{ index, .. } => - Ok(batch.column(*index).clone()), + Ok(batch.column(*index % batch.num_columns()).clone()), ScalarExpression::Alias{ expr, .. } => expr.eval_column(batch), ScalarExpression::TypeCast{ expr, ty, .. } => diff --git a/src/planner/operator/join.rs b/src/planner/operator/join.rs index 01a8ca8b..51e9a5ac 100644 --- a/src/planner/operator/join.rs +++ b/src/planner/operator/join.rs @@ -3,7 +3,7 @@ use crate::planner::LogicalPlan; use super::Operator; -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone, Copy)] pub enum JoinType { Inner, Left, diff --git a/src/planner/operator/scan.rs b/src/planner/operator/scan.rs index 6829d37f..17988698 100644 --- a/src/planner/operator/scan.rs +++ b/src/planner/operator/scan.rs @@ -1,13 +1,13 @@ -use crate::types::TableIdx; -use crate::{catalog::ColumnRefId, expression::ScalarExpression}; +use crate::types::{ColumnId, TableId}; +use crate::expression::ScalarExpression; use crate::planner::LogicalPlan; use super::{sort::SortField, Operator}; #[derive(Debug, PartialEq, Clone)] pub struct ScanOperator { - pub table_ref_id: TableIdx, - pub columns: Vec, + pub table_id: TableId, + pub columns: Vec, pub sort_fields: Vec, // Support push down predicate. // If pre_where is simple predicate, for example: a > 1 then can calculate directly when read data. @@ -16,10 +16,10 @@ pub struct ScanOperator { pub limit: Option, } impl ScanOperator { - pub fn new(table_ref_id: TableIdx) -> LogicalPlan { + pub fn new(table_id: TableId) -> LogicalPlan { LogicalPlan { operator: Operator::Scan(ScanOperator { - table_ref_id, + table_id, columns: vec![], sort_fields: vec![], pre_where: vec![], diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 9dc49bcc..dccf412a 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::sync::Arc; use arrow::record_batch::RecordBatch; @@ -5,7 +6,7 @@ use parking_lot::Mutex; use crate::catalog::{ColumnCatalog, ColumnDesc, RootCatalog}; use crate::storage::{Bounds, Projections, Storage, StorageError, Table, Transaction}; -use crate::types::{LogicalType, TableIdx}; +use crate::types::{LogicalType, TableId}; #[derive(Debug)] pub struct InMemoryStorage { @@ -15,7 +16,7 @@ pub struct InMemoryStorage { #[derive(Debug)] struct StorageInner { catalog: RootCatalog, - tables: Vec, + tables: BTreeMap, } impl Default for InMemoryStorage { @@ -30,7 +31,7 @@ impl InMemoryStorage { inner: Arc::new(Mutex::new( StorageInner { catalog: RootCatalog::default(), - tables: Vec::new(), + tables: BTreeMap::new(), }) ) } @@ -52,7 +53,7 @@ impl Storage for InMemoryStorage { &self, table_name: &str, data: Vec, - ) -> Result { + ) -> Result { let mut table = InMemoryTable::new(table_name, data)?; let mut inner = self.inner.lock(); @@ -67,12 +68,12 @@ impl Storage for InMemoryStorage { Ok(table_id) } - fn get_table(&self, id: TableIdx) -> Result { + fn get_table(&self, id: &TableId) -> Result { self.inner.lock() .tables .get(id) .cloned() - .ok_or(StorageError::TableNotFound(id)) + .ok_or(StorageError::TableNotFound(*id)) } fn get_catalog(&self) -> RootCatalog { @@ -87,7 +88,7 @@ impl Storage for InMemoryStorage { #[derive(Debug, Clone)] pub struct InMemoryTable { - table_id: TableIdx, + table_id: TableId, table_name: String, inner: Arc> } @@ -214,7 +215,7 @@ mod storage_test { assert!(table_catalog.is_some()); assert!(table_catalog.unwrap().get_column_id_by_name("a").is_some()); - let table = storage.get_table(id)?; + let table = storage.get_table(&id)?; let mut tx = table.read(None, None)?; let batch = tx.next_batch()?; println!("{:?}", batch); diff --git a/src/storage/mod.rs b/src/storage/mod.rs index a89d6b66..3811c164 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -7,7 +7,7 @@ use arrow::record_batch::RecordBatch; use crate::catalog::{CatalogError, RootCatalog}; use crate::storage::memory::InMemoryStorage; -use crate::types::TableIdx; +use crate::types::TableId; #[derive(Debug)] pub enum StorageImpl { @@ -21,8 +21,8 @@ pub trait Storage: Sync + Send + 'static { &self, table_name: &str, columns: Vec, - ) -> Result; - fn get_table(&self, id: TableIdx) -> Result; + ) -> Result; + fn get_table(&self, id: &TableId) -> Result; fn get_catalog(&self) -> RootCatalog; fn show_tables(&self) -> Result; } @@ -60,7 +60,7 @@ pub enum StorageError { IoError(#[from] io::Error), #[error("table not found: {0}")] - TableNotFound(TableIdx), + TableNotFound(TableId), #[error("catalog error")] CatalogError(#[from] CatalogError), diff --git a/src/types/mod.rs b/src/types/mod.rs index 8e6ec90f..e749a7c7 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,7 +1,8 @@ pub mod errors; pub mod value; -use std::mem; +use std::sync::atomic::AtomicU32; +use std::sync::atomic::Ordering::{Acquire, Release}; use arrow::datatypes::IntervalUnit; use integer_encoding::FixedInt; @@ -9,34 +10,32 @@ use strum_macros::AsRefStr; use crate::types::errors::TypeError; -#[derive(Debug, Clone, PartialEq)] -pub(crate) struct IdGenerator { - buf: usize -} +static ID_BUF: AtomicU32 = AtomicU32::new(0); + +pub(crate) struct IdGenerator { } impl IdGenerator { - pub(crate) fn encode_to_raw(&self) -> Vec { - self.buf.encode_fixed_vec() + pub(crate) fn encode_to_raw() -> Vec { + ID_BUF + .load(Acquire) + .encode_fixed_vec() } - pub(crate) fn decode_from_raw(&mut self, buf: &[u8]) { - self.buf = u32::decode_fixed(buf) as usize; + pub(crate) fn from_raw(buf: &[u8]) { + Self::init(u32::decode_fixed(buf)) } - pub(crate) fn new() -> Self { - IdGenerator { - buf: 0 - } + pub(crate) fn init(init_value: u32) { + ID_BUF.store(init_value, Release) } - pub(crate) fn build(&mut self) -> usize { - let next_idx = self.buf + 1; - mem::replace(&mut self.buf, next_idx) + pub(crate) fn build() -> u32 { + ID_BUF.fetch_add(1, Release) } } -pub type TableIdx = usize; -pub type ColumnIdx = usize; +pub type TableId = u32; +pub type ColumnId = u32; /// Sqlrs type conversion: /// sqlparser::ast::DataType -> LogicalType -> arrow::datatypes::DataType @@ -348,25 +347,29 @@ impl std::fmt::Display for LogicalType { #[cfg(test)] mod test { + use std::sync::atomic::Ordering::Release; - use crate::types::IdGenerator; + use crate::types::{IdGenerator, ID_BUF}; + /// Tips: 由于IdGenerator为static全局性质生成的id,因此需要单独测试避免其他测试方法干扰 #[test] + #[ignore] fn test_id_generator() { - let mut generator_1 = IdGenerator::new(); + assert_eq!(IdGenerator::build(), 0); + assert_eq!(IdGenerator::build(), 1); - assert_eq!(generator_1.build(), 0); - assert_eq!(generator_1.build(), 1); + let buf = IdGenerator::encode_to_raw(); + test_id_generator_reset(); - let buf = generator_1.encode_to_raw(); + assert_eq!(IdGenerator::build(), 0); - let mut generator_2 = IdGenerator::new(); + IdGenerator::from_raw(&buf); - assert_eq!(generator_2.build(), 0); - - generator_2.decode_from_raw(&buf); + assert_eq!(IdGenerator::build(), 2); + assert_eq!(IdGenerator::build(), 3); + } - assert_eq!(generator_2.build(), 2); - assert_eq!(generator_2.build(), 3); + fn test_id_generator_reset() { + ID_BUF.store(0, Release) } } From 6d91df2e4fdfe58f4def1561bed257dd25f15409 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Mon, 7 Aug 2023 16:09:44 +0800 Subject: [PATCH 4/7] fix(insert): fix the misalignment between the specified field and the data when inserting --- src/db.rs | 14 ++++++++----- src/execution_v1/volcano_executor/insert.rs | 23 +++++++++++---------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/db.rs b/src/db.rs index 1f2211b0..14951095 100644 --- a/src/db.rs +++ b/src/db.rs @@ -151,12 +151,16 @@ mod test { tokio_test::block_on(async move { let _ = kipsql.run("create table t1 (a int, b int)").await?; let _ = kipsql.run("create table t2 (c int, d int)").await?; - let _ = kipsql.run("insert into t1 values (1, 1), (3, 3), (5, 4)").await?; - let _ = kipsql.run("insert into t2 values (1, 2), (2, 3), (5, 6)").await?; + let _ = kipsql.run("insert into t1 (b, a) values (1, 1), (3, 3), (5, 4)").await?; + let _ = kipsql.run("insert into t2 (d, c) values (1, 2), (2, 3), (5, 6)").await?; - println!("full:"); - let vec_batch_full_fields = kipsql.run("select * from t1").await?; - print_batches(&vec_batch_full_fields)?; + println!("full t1:"); + let vec_batch_full_fields_t1 = kipsql.run("select * from t1").await?; + print_batches(&vec_batch_full_fields_t1)?; + + println!("full t2:"); + let vec_batch_full_fields_t2 = kipsql.run("select * from t2").await?; + print_batches(&vec_batch_full_fields_t2)?; println!("projection_and_filter:"); let vec_batch_projection_a = kipsql.run("select a from t1 where a <= b").await?; diff --git a/src/execution_v1/volcano_executor/insert.rs b/src/execution_v1/volcano_executor/insert.rs index 7c76cda7..737d3408 100644 --- a/src/execution_v1/volcano_executor/insert.rs +++ b/src/execution_v1/volcano_executor/insert.rs @@ -1,5 +1,6 @@ -use arrow::array::{Array, new_null_array}; -use arrow::datatypes::DataType; +use ahash::HashMap; +use arrow::array::{Array, ArrayRef, new_null_array}; +use arrow::datatypes::{DataType, Field}; use arrow::record_batch::RecordBatch; use futures_async_stream::try_stream; use itertools::Itertools; @@ -20,19 +21,19 @@ impl Insert { let projection_schema = batch.schema(); let fields = projection_schema.fields(); - let mut arrays = batch.columns().to_vec(); - let col_len = arrays[0].len(); + let arrays = batch.columns(); - arrays.reverse(); + let col_len = arrays[0].len(); + let insert_values: HashMap = (0..fields.len()).into_iter() + .map(|i| (fields[i].clone(), arrays[i].clone())) + .collect(); let full_arrays = table.all_columns() .into_iter() - .map(|(_, col_catalog)| { - if fields.contains(&col_catalog.to_field()) { - arrays.pop().unwrap() - } else { - new_null_array(&DataType::from(col_catalog.datatype().clone()), col_len) - } + .filter_map(|(_, col_catalog)| { + insert_values.get(&col_catalog.to_field()) + .map(ArrayRef::clone) + .or_else(|| Some(new_null_array(&DataType::from(col_catalog.datatype().clone()), col_len))) }) .collect_vec(); From 08ee27ab424843285b77510fff7ba1e6c6d49f8d Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Mon, 7 Aug 2023 20:58:14 +0800 Subject: [PATCH 5/7] refactor(physical_plan_builder): for subsequent Agg compatibility, cancel converting InputRef to ColumnRef --- src/binder/aggregate.rs | 6 +- src/binder/mod.rs | 3 +- src/binder/select.rs | 19 +++--- src/catalog/column.rs | 8 ++- src/catalog/root.rs | 8 ++- src/catalog/table.rs | 8 +-- .../physical_plan/physical_plan_builder.rs | 61 +++++++++++-------- .../volcano_executor/hash_join.rs | 1 - src/execution_v1/volcano_executor/insert.rs | 2 +- src/expression/evaluator.rs | 5 +- 10 files changed, 66 insertions(+), 55 deletions(-) diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs index b38681f5..13708f50 100644 --- a/src/binder/aggregate.rs +++ b/src/binder/aggregate.rs @@ -90,11 +90,7 @@ impl Binder { ScalarExpression::AggCall { ty: return_type, .. } => { - let index = if self.context.agg_calls.len() == 0 { - 0 - } else { - self.context.agg_calls.len() + 1 - }; + let index = self.context.agg_calls.len(); let input_ref = ScalarExpression::InputRef { index, ty: return_type.clone(), diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 4c581737..801c0ffc 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -12,11 +12,12 @@ use sqlparser::ast::{Ident, ObjectName, SetExpr, Statement}; use crate::catalog::{RootCatalog, DEFAULT_SCHEMA_NAME, CatalogError}; use crate::expression::ScalarExpression; use crate::planner::LogicalPlan; +use crate::planner::operator::join::JoinType; use crate::types::TableId; #[derive(Debug, Clone)] pub struct BinderContext { pub(crate) catalog: RootCatalog, - pub(crate) bind_table: BTreeMap, + pub(crate) bind_table: BTreeMap)>, aliases: BTreeMap, group_by_exprs: Vec, agg_calls: Vec, diff --git a/src/binder/select.rs b/src/binder/select.rs index 11c79fb9..9b5d9416 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -114,7 +114,7 @@ impl Binder { let TableWithJoins { relation, joins } = &from[0]; - let (left_id, mut plan) = self.bind_single_table_ref(relation)?; + let (left_id, mut plan) = self.bind_single_table_ref(relation, None)?; if !joins.is_empty() { for join in joins { @@ -124,7 +124,7 @@ impl Binder { Ok(plan) } - fn bind_single_table_ref(&mut self, table: &TableFactor) -> Result<(TableId, LogicalPlan)> { + fn bind_single_table_ref(&mut self, table: &TableFactor, joint_type: Option) -> Result<(TableId, LogicalPlan)> { let plan_with_id = match table { TableFactor::Table { name, alias, .. } => { let obj_name = name @@ -157,7 +157,7 @@ impl Binder { .get_table_id_by_name(table) .ok_or_else(|| anyhow::Error::msg(format!("bind table {}", table)))?; - self.context.bind_table.insert(table.into(), table_ref_id); + self.context.bind_table.insert(table.into(), (table_ref_id, joint_type)); (table_ref_id, ScanOperator::new(table_ref_id)) } @@ -203,7 +203,7 @@ impl Binder { fn bind_all_column_refs(&mut self) -> Result> { let mut exprs = vec![]; - for table_id in self.context.bind_table.values().cloned().collect_vec() { + for (table_id, _) in self.context.bind_table.values().cloned() { let table = self.context.catalog.get_table(&table_id).unwrap(); for (_, col) in table.all_columns() { exprs.push(ScalarExpression::ColumnRef(col.clone())); @@ -219,9 +219,7 @@ impl Binder { join_operator, } = join; - let (right_id, right) = self.bind_single_table_ref(relation)?; - - let join_type = match join_operator { + let (join_type, joint_condition) = match join_operator { JoinOperator::Inner(constraint) => (JoinType::Inner, Some(constraint)), JoinOperator::LeftOuter(constraint) => (JoinType::Left, Some(constraint)), JoinOperator::RightOuter(constraint) => (JoinType::Right, Some(constraint)), @@ -229,6 +227,9 @@ impl Binder { JoinOperator::CrossJoin => (JoinType::Cross, None), _ => unimplemented!(), }; + + let (right_id, right) = self.bind_single_table_ref(relation, Some(join_type))?; + let left_table = self.context.catalog .get_table(&left_id) .cloned() @@ -238,7 +239,7 @@ impl Binder { .cloned() .expect("Right table not found"); - let on = match join_type.1 { + let on = match joint_condition { Some(constraint) => self.bind_join_constraint( &left_table, &right_table, @@ -247,7 +248,7 @@ impl Binder { None => JoinCondition::None, }; - Ok(LJoinOperator::new(left, right, on, join_type.0)) + Ok(LJoinOperator::new(left, right, on, join_type)) } fn bind_where( diff --git a/src/catalog/column.rs b/src/catalog/column.rs index 20dd2ed5..e8b258c9 100644 --- a/src/catalog/column.rs +++ b/src/catalog/column.rs @@ -1,12 +1,13 @@ use arrow::datatypes::{DataType, Field}; use sqlparser::ast::{ColumnDef, ColumnOption}; -use crate::types::{ColumnId, LogicalType}; +use crate::types::{ColumnId, IdGenerator, LogicalType, TableId}; #[derive(Debug, Clone, PartialEq)] pub struct ColumnCatalog { - pub id: Option, + pub id: ColumnId, pub name: String, + pub table_id: Option, pub nullable: bool, pub desc: ColumnDesc, } @@ -14,8 +15,9 @@ pub struct ColumnCatalog { impl ColumnCatalog { pub(crate) fn new(column_name: String, nullable: bool, column_desc: ColumnDesc) -> ColumnCatalog { ColumnCatalog { - id: None, + id: IdGenerator::build(), name: column_name, + table_id: None, nullable, desc: column_desc, } diff --git a/src/catalog/root.rs b/src/catalog/root.rs index f36d9473..0bfaf2d5 100644 --- a/src/catalog/root.rs +++ b/src/catalog/root.rs @@ -45,10 +45,12 @@ impl RootCatalog { if self.table_idxs.contains_key(&table_name) { return Err(CatalogError::Duplicated("column", table_name)); } - let mut table = TableCatalog::new(table_name.to_owned(), columns)?; - let table_id = IdGenerator::build(); + let table = TableCatalog::new( + table_name.to_owned(), + columns + )?; + let table_id = table.id; - table.id = Some(table_id); self.table_idxs.insert(table_name, table_id); self.tables.insert(table_id, table); diff --git a/src/catalog/table.rs b/src/catalog/table.rs index 1157922a..10bc648b 100644 --- a/src/catalog/table.rs +++ b/src/catalog/table.rs @@ -8,7 +8,7 @@ use crate::catalog::{CatalogError, ColumnCatalog}; use crate::types::{ColumnId, IdGenerator, TableId}; #[derive(Debug, Clone, PartialEq)] pub struct TableCatalog { - pub id: Option, + pub id: TableId, pub name: String, /// Mapping from column names to column ids column_idxs: BTreeMap, @@ -60,9 +60,9 @@ impl TableCatalog { return Err(CatalogError::Duplicated("column", col_catalog.name.into())); } - let col_id = IdGenerator::build(); + let col_id = col_catalog.id; - col_catalog.id = Some(col_id); + col_catalog.table_id = Some(self.id); self.column_idxs.insert(col_catalog.name.to_owned(), col_id); self.columns.insert(col_id, col_catalog); @@ -74,7 +74,7 @@ impl TableCatalog { columns: Vec, ) -> Result { let mut table_catalog = TableCatalog { - id: None, + id: IdGenerator::build(), name: table_name, column_idxs: BTreeMap::new(), columns: BTreeMap::new(), diff --git a/src/execution_v1/physical_plan/physical_plan_builder.rs b/src/execution_v1/physical_plan/physical_plan_builder.rs index 80d99347..07746a51 100644 --- a/src/execution_v1/physical_plan/physical_plan_builder.rs +++ b/src/execution_v1/physical_plan/physical_plan_builder.rs @@ -1,5 +1,4 @@ -use std::mem; -use ahash::HashMap; +use ahash::{HashMap, HashMapExt}; use crate::execution_v1::physical_plan::physical_create_table::PhysicalCreateTable; use crate::execution_v1::physical_plan::physical_projection::PhysicalProjection; use crate::execution_v1::physical_plan::physical_table_scan::PhysicalTableScan; @@ -26,34 +25,41 @@ use crate::planner::operator::limit::LimitOperator; use crate::planner::operator::project::ProjectOperator; use crate::planner::operator::sort::{SortField, SortOperator}; use crate::planner::operator::values::ValuesOperator; -use crate::types::ColumnId; +use crate::types::TableId; pub struct PhysicalPlanBuilder { - column_index: HashMap + table_force_nullable: HashMap + } impl PhysicalPlanBuilder { pub fn new(context: BinderContext) -> Self { - let mut pos = 0usize; - let root = &context.catalog; - let column_index = context.bind_table - .iter() - .filter_map(|(_, table_id)| { - root.get_table(table_id) - .map(|table| { - table.all_columns() - .iter() - .map(|(col_id, _)| { - let next_pos = pos + 1; - (**col_id, mem::replace(&mut pos, next_pos)) - }) - .collect_vec() - }) - }) - .flatten() - .collect(); + let bind_tables = &context.bind_table; + let mut table_force_nullable = HashMap::new(); + let mut left_table_force_nullable = false; + let mut left_table_id = None; - PhysicalPlanBuilder { column_index } + for (table_id, join_option) in bind_tables.values() { + if let Some(join_type) = join_option { + let (left_force_nullable, right_force_nullable) = match join_type { + JoinType::Inner => (false, false), + JoinType::Left => (false, true), + JoinType::Right => (true, false), + JoinType::Full => (true, true), + JoinType::Cross => (true, true), + }; + table_force_nullable.insert(*table_id, right_force_nullable); + left_table_force_nullable = left_force_nullable; + } else { + left_table_id = Some(*table_id); + } + } + + if let Some(id) = left_table_id { + table_force_nullable.insert(id, left_table_force_nullable); + } + + PhysicalPlanBuilder { table_force_nullable } } pub fn build_plan(&mut self, plan: &LogicalPlan) -> Result { @@ -193,11 +199,12 @@ impl PhysicalPlanBuilder { fn rewriter_expr(&mut self, expr: &ScalarExpression) -> ScalarExpression { match expr { ScalarExpression::ColumnRef(col) => { - ScalarExpression::InputRef { - // FIXME: remove unwrap - index: *self.column_index.get(&col.id.unwrap()).unwrap(), - ty: col.datatype().clone(), + let mut new_col = col.clone(); + if let Some(nullable) = self.table_force_nullable.get(&col.table_id.unwrap()) { + new_col.nullable = *nullable; } + + ScalarExpression::ColumnRef(new_col) } ScalarExpression::Alias { expr, alias } => { ScalarExpression::Alias { diff --git a/src/execution_v1/volcano_executor/hash_join.rs b/src/execution_v1/volcano_executor/hash_join.rs index d3cd7a70..33072887 100644 --- a/src/execution_v1/volcano_executor/hash_join.rs +++ b/src/execution_v1/volcano_executor/hash_join.rs @@ -36,7 +36,6 @@ impl HashJoin { let hash_random_state = RandomState::with_seeds(0, 0, 0, 0); let mut join_fields: Vec = Vec::new(); - // FIXME: 应该在Binder层处理,使Project能够同时获取此nullable信息(因为Join会改变主键的nullable判定) let (left_force_nullable, right_force_nullable) = match ty { JoinType::Inner => (false, false), JoinType::Left => (false, true), diff --git a/src/execution_v1/volcano_executor/insert.rs b/src/execution_v1/volcano_executor/insert.rs index 737d3408..e60f2611 100644 --- a/src/execution_v1/volcano_executor/insert.rs +++ b/src/execution_v1/volcano_executor/insert.rs @@ -39,7 +39,7 @@ impl Insert { let new_batch = RecordBatch::try_new(table.schema(), full_arrays)?; - storage.get_table(&table.id.unwrap())?.append(new_batch)?; + storage.get_table(&table.id)?.append(new_batch)?; } } else { Err(CatalogError::NotFound("root", table_name.to_string()))?; diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index 81f31160..09c9ba50 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -12,7 +12,10 @@ impl ScalarExpression { match &self { ScalarExpression::Constant(val) => Ok(val.to_array_of_size(batch.num_rows())), - ScalarExpression::ColumnRef(_) => unreachable!("column ref should be resolved"), + ScalarExpression::ColumnRef(col) => { + let index = batch.schema().index_of(&col.name)?; + Ok(batch.column(index).clone()) + }, ScalarExpression::InputRef{ index, .. } => Ok(batch.column(*index % batch.num_columns()).clone()), ScalarExpression::Alias{ expr, .. } => From 0fcb6e2388ed1462d448233dbe7e7f803c2cf361 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Mon, 7 Aug 2023 21:05:08 +0800 Subject: [PATCH 6/7] fix(create_table): fix `test_create_bind` --- src/binder/create_table.rs | 38 ++++++++++++++------------------------ src/catalog/root.rs | 2 +- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 381a8b23..1895ccd3 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -52,38 +52,28 @@ impl Binder { mod tests { use super::*; use crate::binder::BinderContext; - use crate::catalog::{ColumnCatalog, ColumnDesc, RootCatalog}; - use crate::planner::LogicalPlan; + use crate::catalog::{ColumnDesc, RootCatalog}; use crate::types::LogicalType; #[test] fn test_create_bind() { - let sql = "create table t1 (id int , name varchar(10))"; + let sql = "create table t1 (id int , name varchar(10) null)"; let binder = Binder::new(BinderContext::new(RootCatalog::new())); let stmt = crate::parser::parse_sql(sql).unwrap(); let (plan1, _) = binder.bind(&stmt[0]).unwrap(); - let plan2 = LogicalPlan { - operator: Operator::CreateTable( - CreateTableOperator { - table_name: "t1".to_string(), - columns: vec![ - ColumnCatalog::new( - "id".to_string(), - false, - ColumnDesc::new(LogicalType::Integer, false) - ), - ColumnCatalog::new( - "name".to_string(), - false, - ColumnDesc::new(LogicalType::Varchar, false) - ) - ], - } - ), - childrens: vec![], - }; + match plan1.operator { + Operator::CreateTable(op) => { + assert_eq!(op.table_name, "t1".to_string()); + assert_eq!(op.columns[0].name, "id".to_string()); + assert_eq!(op.columns[0].nullable, false); + assert_eq!(op.columns[0].desc, ColumnDesc::new(LogicalType::Integer, false)); + assert_eq!(op.columns[1].name, "name".to_string()); + assert_eq!(op.columns[1].nullable, true); + assert_eq!(op.columns[1].desc, ColumnDesc::new(LogicalType::Varchar, false)); + } + _ => unreachable!() + } - assert_eq!(plan1, plan2); } } diff --git a/src/catalog/root.rs b/src/catalog/root.rs index 0bfaf2d5..7f5a00b5 100644 --- a/src/catalog/root.rs +++ b/src/catalog/root.rs @@ -1,7 +1,7 @@ use std::collections::BTreeMap; use crate::catalog::{CatalogError, ColumnCatalog, TableCatalog}; -use crate::types::{IdGenerator, TableId}; +use crate::types::TableId; #[derive(Debug, Clone)] pub struct RootCatalog { From 64304977d9ea27cdf6c0d1e8673bc8bfa15c4de4 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Tue, 8 Aug 2023 21:01:53 +0800 Subject: [PATCH 7/7] refactor(binder): when selecting Join in `Binder::bind`, perform corresponding Select nullable processing to avoid rewrite --- src/binder/create_table.rs | 2 +- src/binder/mod.rs | 4 +- src/binder/select.rs | 47 +++++- src/db.rs | 4 +- .../physical_plan/physical_plan_builder.rs | 156 ++---------------- 5 files changed, 65 insertions(+), 148 deletions(-) diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 1895ccd3..0a8b98fb 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -60,7 +60,7 @@ mod tests { let sql = "create table t1 (id int , name varchar(10) null)"; let binder = Binder::new(BinderContext::new(RootCatalog::new())); let stmt = crate::parser::parse_sql(sql).unwrap(); - let (plan1, _) = binder.bind(&stmt[0]).unwrap(); + let plan1 = binder.bind(&stmt[0]).unwrap(); match plan1.operator { Operator::CreateTable(op) => { diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 801c0ffc..54149df8 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -60,7 +60,7 @@ impl Binder { Binder { context } } - pub fn bind(mut self, stmt: &Statement) -> Result<(LogicalPlan, BinderContext)> { + pub fn bind(mut self, stmt: &Statement) -> Result { let plan = match stmt { Statement::Query(query) => self.bind_query(query)?, Statement::CreateTable { name, columns, .. } => self.bind_create_table(name, &columns)?, @@ -73,7 +73,7 @@ impl Binder { } _ => unimplemented!(), }; - Ok((plan, self.context)) + Ok(plan) } } diff --git a/src/binder/select.rs b/src/binder/select.rs index 9b5d9416..8d0f690f 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -1,4 +1,5 @@ use std::borrow::Borrow; +use std::collections::HashMap; use crate::{ expression::ScalarExpression, @@ -62,6 +63,8 @@ impl Binder { let mut select_list = self.normalize_select_item(&select.projection)?; + self.extract_select_join(&mut select_list); + if let Some(predicate) = &select.selection { plan = self.bind_where(plan, predicate)?; } @@ -336,6 +339,48 @@ impl Binder { Ok(LimitOperator::new(offset, limit, children)) } + pub fn extract_select_join( + &mut self, + select_items: &mut [ScalarExpression], + ) { + let bind_tables = &self.context.bind_table; + if bind_tables.len() < 2 { + return; + } + + let mut table_force_nullable = HashMap::new(); + let mut left_table_force_nullable = false; + let mut left_table_id = None; + + for (table_id, join_option) in bind_tables.values() { + if let Some(join_type) = join_option { + let (left_force_nullable, right_force_nullable) = match join_type { + JoinType::Inner => (false, false), + JoinType::Left => (false, true), + JoinType::Right => (true, false), + JoinType::Full => (true, true), + JoinType::Cross => (true, true), + }; + table_force_nullable.insert(*table_id, right_force_nullable); + left_table_force_nullable = left_force_nullable; + } else { + left_table_id = Some(*table_id); + } + } + + if let Some(id) = left_table_id { + table_force_nullable.insert(id, left_table_force_nullable); + } + + for column in select_items { + if let ScalarExpression::ColumnRef(col) = column { + if let Some(nullable) = table_force_nullable.get(&col.table_id.unwrap()) { + col.nullable = *nullable; + } + } + } + } + fn bind_join_constraint( &mut self, left_table: &TableCatalog, @@ -481,7 +526,7 @@ mod tests { let binder = Binder::new(BinderContext::new(root)); let stmt = crate::parser::parse_sql(sql).unwrap(); - Ok(binder.bind(&stmt[0])?.0) + Ok(binder.bind(&stmt[0])?) } #[test] diff --git a/src/db.rs b/src/db.rs index 14951095..7ce3346e 100644 --- a/src/db.rs +++ b/src/db.rs @@ -44,10 +44,10 @@ impl Database { /// Sort(a) /// Limit(1) /// Project(a,b) - let (logical_plan, bind_context) = binder.bind(&stmts[0])?; + let logical_plan = binder.bind(&stmts[0])?; // println!("logic plan: {:#?}", logical_plan); - let mut builder = PhysicalPlanBuilder::new(bind_context); + let mut builder = PhysicalPlanBuilder::new(); let operator = builder.build_plan(&logical_plan)?; // println!("operator: {:#?}", operator); diff --git a/src/execution_v1/physical_plan/physical_plan_builder.rs b/src/execution_v1/physical_plan/physical_plan_builder.rs index 07746a51..c0d223e5 100644 --- a/src/execution_v1/physical_plan/physical_plan_builder.rs +++ b/src/execution_v1/physical_plan/physical_plan_builder.rs @@ -8,63 +8,31 @@ use crate::planner::operator::Operator; use crate::planner::LogicalPlan; use anyhow::anyhow; use anyhow::Result; -use itertools::Itertools; -use crate::binder::BinderContext; use crate::execution_v1::physical_plan::physical_filter::PhysicalFilter; use crate::execution_v1::physical_plan::physical_hash_join::PhysicalHashJoin; use crate::execution_v1::physical_plan::physical_insert::PhysicalInsert; use crate::execution_v1::physical_plan::physical_limit::PhysicalLimit; use crate::execution_v1::physical_plan::physical_sort::PhysicalSort; use crate::execution_v1::physical_plan::physical_values::PhysicalValues; -use crate::expression::ScalarExpression; use crate::planner::operator::create_table::CreateTableOperator; use crate::planner::operator::filter::FilterOperator; use crate::planner::operator::insert::InsertOperator; -use crate::planner::operator::join::{JoinCondition, JoinOperator, JoinType}; +use crate::planner::operator::join::{JoinOperator, JoinType}; use crate::planner::operator::limit::LimitOperator; use crate::planner::operator::project::ProjectOperator; -use crate::planner::operator::sort::{SortField, SortOperator}; +use crate::planner::operator::sort::SortOperator; use crate::planner::operator::values::ValuesOperator; -use crate::types::TableId; -pub struct PhysicalPlanBuilder { - table_force_nullable: HashMap - -} +pub struct PhysicalPlanBuilder { } impl PhysicalPlanBuilder { - pub fn new(context: BinderContext) -> Self { - let bind_tables = &context.bind_table; - let mut table_force_nullable = HashMap::new(); - let mut left_table_force_nullable = false; - let mut left_table_id = None; - - for (table_id, join_option) in bind_tables.values() { - if let Some(join_type) = join_option { - let (left_force_nullable, right_force_nullable) = match join_type { - JoinType::Inner => (false, false), - JoinType::Left => (false, true), - JoinType::Right => (true, false), - JoinType::Full => (true, true), - JoinType::Cross => (true, true), - }; - table_force_nullable.insert(*table_id, right_force_nullable); - left_table_force_nullable = left_force_nullable; - } else { - left_table_id = Some(*table_id); - } - } - - if let Some(id) = left_table_id { - table_force_nullable.insert(id, left_table_force_nullable); - } - - PhysicalPlanBuilder { table_force_nullable } + pub fn new() -> Self { + PhysicalPlanBuilder { } } pub fn build_plan(&mut self, plan: &LogicalPlan) -> Result { match &plan.operator { - Operator::Project(op) => self.build_physical_projection(plan, op), + Operator::Project(op) => self.build_physical_select_projection(plan, op), Operator::Scan(scan) => Ok(self.build_physical_scan(scan.clone())), Operator::Filter(op) => self.build_physical_filter(plan, op), Operator::CreateTable(op) => Ok(self.build_physical_create_table(op)), @@ -104,16 +72,11 @@ impl PhysicalPlanBuilder { ) } - fn build_physical_projection(&mut self, plan: &LogicalPlan, op: &ProjectOperator) -> Result { + fn build_physical_select_projection(&mut self, plan: &LogicalPlan, op: &ProjectOperator) -> Result { let input = self.build_plan(plan.child(0)?)?; - let exprs = op.columns - .iter() - .map(|expr| self.rewriter_expr(expr)) - .collect_vec(); - Ok(PhysicalPlan::Projection(PhysicalProjection { - exprs, + exprs: op.columns.clone(), input: Box::new(input), })) } @@ -126,30 +89,16 @@ impl PhysicalPlanBuilder { let input = self.build_plan(plan.child(0)?)?; Ok(PhysicalPlan::Filter(PhysicalFilter { - predicate: self.rewriter_expr(&base.predicate), + predicate: base.predicate.clone(), input: Box::new(input), })) } - fn build_physical_sort(&mut self, plan: &LogicalPlan, SortOperator { sort_fields, limit }: &SortOperator) -> Result { + fn build_physical_sort(&mut self, plan: &LogicalPlan, base: &SortOperator) -> Result { let input = self.build_plan(plan.child(0)?)?; - let rewrite_sort_fields = sort_fields - .into_iter() - .map(|SortField{ expr, desc, nulls_first }| { - SortField { - expr: self.rewriter_expr(expr), - desc: desc.clone(), - nulls_first: nulls_first.clone(), - } - }) - .collect_vec(); - Ok(PhysicalPlan::Sort(PhysicalSort { - op: SortOperator { - sort_fields: rewrite_sort_fields, - limit: limit.clone(), - }, + op: base.clone(), input: Box::new(input), })) } @@ -163,95 +112,18 @@ impl PhysicalPlanBuilder { })) } - fn build_physical_join(&mut self, plan: &LogicalPlan, JoinOperator{ on, join_type } : &JoinOperator) -> Result { + fn build_physical_join(&mut self, plan: &LogicalPlan, base: &JoinOperator) -> Result { let left_input = Box::new(self.build_plan(plan.child(0)?)?); let right_input = Box::new(self.build_plan(plan.child(1)?)?); - let on = if let JoinCondition::On { on, filter } = on { - let rewrite_on = on.iter() - .map(|(left_expr, right_expr)| { - (self.rewriter_expr(left_expr), self.rewriter_expr(right_expr)) - }) - .collect_vec(); - let filter = filter - .as_ref() - .map(|expr| self.rewriter_expr(expr)); - - JoinCondition::On { on: rewrite_on, filter } - } else { - JoinCondition::None - }; - - if join_type == &JoinType::Cross { + if base.join_type == JoinType::Cross { todo!() } else { Ok(PhysicalPlan::HashJoin(PhysicalHashJoin { - op: JoinOperator { - on, - join_type: join_type.clone(), - }, + op: base.clone(), left_input, right_input, })) } } - - fn rewriter_expr(&mut self, expr: &ScalarExpression) -> ScalarExpression { - match expr { - ScalarExpression::ColumnRef(col) => { - let mut new_col = col.clone(); - if let Some(nullable) = self.table_force_nullable.get(&col.table_id.unwrap()) { - new_col.nullable = *nullable; - } - - ScalarExpression::ColumnRef(new_col) - } - ScalarExpression::Alias { expr, alias } => { - ScalarExpression::Alias { - expr: Box::new(self.rewriter_expr(expr)), - alias: alias.clone() - } - } - ScalarExpression::TypeCast { expr, ty, is_try } => { - ScalarExpression::TypeCast { - expr: Box::new(self.rewriter_expr(expr)), - ty: ty.clone(), - is_try: is_try.clone(), - } - } - ScalarExpression::IsNull { expr } => { - ScalarExpression::IsNull { - expr: Box::new(self.rewriter_expr(expr)) - } - } - ScalarExpression::Unary { op, expr, ty } => { - ScalarExpression::Unary { - op: op.clone(), - expr: Box::new(self.rewriter_expr(expr)), - ty: ty.clone(), - } - } - ScalarExpression::Binary { op, left_expr, right_expr, ty } => { - ScalarExpression::Binary { - op: op.clone(), - left_expr: Box::new(self.rewriter_expr(left_expr)), - right_expr: Box::new(self.rewriter_expr(right_expr)), - ty: ty.clone(), - } - } - ScalarExpression::AggCall { kind, args, ty } => { - let rewrite_args = args - .into_iter() - .map(|expr| self.rewriter_expr(expr)) - .collect_vec(); - - ScalarExpression::AggCall { - kind: kind.clone(), - args: rewrite_args, - ty: ty.clone(), - } - } - _ => expr.clone() - } - } } \ No newline at end of file