diff --git a/benchmarks/expected-plans/q2.txt b/benchmarks/expected-plans/q2.txt index c5f6fb0fd326..e9730550939d 100644 --- a/benchmarks/expected-plans/q2.txt +++ b/benchmarks/expected-plans/q2.txt @@ -1,24 +1,25 @@ Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, supplier.s_comment - Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = __sq_1.__value - Inner Join: nation.n_regionkey = region.r_regionkey - Inner Join: supplier.s_nationkey = nation.n_nationkey - Inner Join: partsupp.ps_suppkey = supplier.s_suppkey - Inner Join: part.p_partkey = partsupp.ps_partkey - Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS") - TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size] - TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] - TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] - TableScan: nation projection=[n_nationkey, n_name, n_regionkey] - Filter: region.r_name = Utf8("EUROPE") - TableScan: region projection=[r_regionkey, r_name] - Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1 - Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]] - Inner Join: nation.n_regionkey = region.r_regionkey - Inner Join: supplier.s_nationkey = nation.n_nationkey - Inner Join: partsupp.ps_suppkey = supplier.s_suppkey + Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, nation.n_name + Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = __sq_1.__value + Inner Join: nation.n_regionkey = region.r_regionkey + Inner Join: supplier.s_nationkey = nation.n_nationkey + Inner Join: partsupp.ps_suppkey = supplier.s_suppkey + Inner Join: part.p_partkey = partsupp.ps_partkey + Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS") + TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size] TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] - TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] - TableScan: nation projection=[n_nationkey, n_name, n_regionkey] - Filter: region.r_name = Utf8("EUROPE") - TableScan: region projection=[r_regionkey, r_name] \ No newline at end of file + TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] + TableScan: nation projection=[n_nationkey, n_name, n_regionkey] + Filter: region.r_name = Utf8("EUROPE") + TableScan: region projection=[r_regionkey, r_name] + Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1 + Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]] + Inner Join: nation.n_regionkey = region.r_regionkey + Inner Join: supplier.s_nationkey = nation.n_nationkey + Inner Join: partsupp.ps_suppkey = supplier.s_suppkey + TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] + TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] + TableScan: nation projection=[n_nationkey, n_name, n_regionkey] + Filter: region.r_name = Utf8("EUROPE") + TableScan: region projection=[r_regionkey, r_name] \ No newline at end of file diff --git a/benchmarks/expected-plans/q8.txt b/benchmarks/expected-plans/q8.txt index 3f5a87680831..1b8d08ef875a 100644 --- a/benchmarks/expected-plans/q8.txt +++ b/benchmarks/expected-plans/q8.txt @@ -3,23 +3,24 @@ Sort: all_nations.o_year ASC NULLS LAST Aggregate: groupBy=[[all_nations.o_year]], aggr=[[SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), SUM(all_nations.volume)]] Projection: o_year, volume, nation, alias=all_nations Projection: datepart(Utf8("YEAR"), orders.o_orderdate) AS o_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume, n2.n_name AS nation - Inner Join: n1.n_regionkey = region.r_regionkey - Inner Join: supplier.s_nationkey = n2.n_nationkey - Inner Join: customer.c_nationkey = n1.n_nationkey - Inner Join: orders.o_custkey = customer.c_custkey - Inner Join: lineitem.l_orderkey = orders.o_orderkey - Inner Join: lineitem.l_suppkey = supplier.s_suppkey - Inner Join: part.p_partkey = lineitem.l_partkey - Filter: part.p_type = Utf8("ECONOMY ANODIZED STEEL") - TableScan: part projection=[p_partkey, p_type] - TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_extendedprice, l_discount] - TableScan: supplier projection=[s_suppkey, s_nationkey] - Filter: orders.o_orderdate >= Date32("9131") AND orders.o_orderdate <= Date32("9861") - TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate] - TableScan: customer projection=[c_custkey, c_nationkey] - SubqueryAlias: n1 - TableScan: nation projection=[n_nationkey, n_regionkey] - SubqueryAlias: n2 - TableScan: nation projection=[n_nationkey, n_name] - Filter: region.r_name = Utf8("AMERICA") - TableScan: region projection=[r_regionkey, r_name] \ No newline at end of file + Projection: lineitem.l_extendedprice, lineitem.l_discount, orders.o_orderdate, n2.n_name + Inner Join: n1.n_regionkey = region.r_regionkey + Inner Join: supplier.s_nationkey = n2.n_nationkey + Inner Join: customer.c_nationkey = n1.n_nationkey + Inner Join: orders.o_custkey = customer.c_custkey + Inner Join: lineitem.l_orderkey = orders.o_orderkey + Inner Join: lineitem.l_suppkey = supplier.s_suppkey + Inner Join: part.p_partkey = lineitem.l_partkey + Filter: part.p_type = Utf8("ECONOMY ANODIZED STEEL") + TableScan: part projection=[p_partkey, p_type] + TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_extendedprice, l_discount] + TableScan: supplier projection=[s_suppkey, s_nationkey] + Filter: orders.o_orderdate >= Date32("9131") AND orders.o_orderdate <= Date32("9861") + TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate] + TableScan: customer projection=[c_custkey, c_nationkey] + SubqueryAlias: n1 + TableScan: nation projection=[n_nationkey, n_regionkey] + SubqueryAlias: n2 + TableScan: nation projection=[n_nationkey, n_name] + Filter: region.r_name = Utf8("AMERICA") + TableScan: region projection=[r_regionkey, r_name] \ No newline at end of file diff --git a/benchmarks/expected-plans/q9.txt b/benchmarks/expected-plans/q9.txt index 339db70175db..ae7d4f194a8c 100644 --- a/benchmarks/expected-plans/q9.txt +++ b/benchmarks/expected-plans/q9.txt @@ -3,15 +3,16 @@ Sort: profit.nation ASC NULLS LAST, profit.o_year DESC NULLS FIRST Aggregate: groupBy=[[profit.nation, profit.o_year]], aggr=[[SUM(profit.amount)]] Projection: nation, o_year, amount, alias=profit Projection: nation.n_name AS nation, datepart(Utf8("YEAR"), orders.o_orderdate) AS o_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) - CAST(partsupp.ps_supplycost * lineitem.l_quantity AS Decimal128(38, 4)) AS amount - Inner Join: supplier.s_nationkey = nation.n_nationkey - Inner Join: lineitem.l_orderkey = orders.o_orderkey - Inner Join: lineitem.l_suppkey = partsupp.ps_suppkey, lineitem.l_partkey = partsupp.ps_partkey - Inner Join: lineitem.l_suppkey = supplier.s_suppkey - Inner Join: part.p_partkey = lineitem.l_partkey - Filter: part.p_name LIKE Utf8("%green%") - TableScan: part projection=[p_partkey, p_name] - TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount] - TableScan: supplier projection=[s_suppkey, s_nationkey] - TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] - TableScan: orders projection=[o_orderkey, o_orderdate] - TableScan: nation projection=[n_nationkey, n_name] \ No newline at end of file + Projection: lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, partsupp.ps_supplycost, orders.o_orderdate, nation.n_name + Inner Join: supplier.s_nationkey = nation.n_nationkey + Inner Join: lineitem.l_orderkey = orders.o_orderkey + Inner Join: lineitem.l_suppkey = partsupp.ps_suppkey, lineitem.l_partkey = partsupp.ps_partkey + Inner Join: lineitem.l_suppkey = supplier.s_suppkey + Inner Join: part.p_partkey = lineitem.l_partkey + Filter: part.p_name LIKE Utf8("%green%") + TableScan: part projection=[p_partkey, p_name] + TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount] + TableScan: supplier projection=[s_suppkey, s_nationkey] + TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] + TableScan: orders projection=[o_orderkey, o_orderdate] + TableScan: nation projection=[n_nationkey, n_name] \ No newline at end of file diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 064ef3a35edf..98bf56a02ccf 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -141,28 +141,29 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#; let actual = format!("{}", plan.display_indent()); let expected = r#"Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, supplier.s_comment - Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = __sq_1.__value - Inner Join: nation.n_regionkey = region.r_regionkey - Inner Join: supplier.s_nationkey = nation.n_nationkey - Inner Join: partsupp.ps_suppkey = supplier.s_suppkey - Inner Join: part.p_partkey = partsupp.ps_partkey - Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS") - TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")] - TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] - TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] - TableScan: nation projection=[n_nationkey, n_name, n_regionkey] - Filter: region.r_name = Utf8("EUROPE") - TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] - Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1 - Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]] - Inner Join: nation.n_regionkey = region.r_regionkey - Inner Join: supplier.s_nationkey = nation.n_nationkey - Inner Join: partsupp.ps_suppkey = supplier.s_suppkey + Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, nation.n_name + Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = __sq_1.__value + Inner Join: nation.n_regionkey = region.r_regionkey + Inner Join: supplier.s_nationkey = nation.n_nationkey + Inner Join: partsupp.ps_suppkey = supplier.s_suppkey + Inner Join: part.p_partkey = partsupp.ps_partkey + Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS") + TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")] TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] - TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] - TableScan: nation projection=[n_nationkey, n_name, n_regionkey] - Filter: region.r_name = Utf8("EUROPE") - TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]"# + TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] + TableScan: nation projection=[n_nationkey, n_name, n_regionkey] + Filter: region.r_name = Utf8("EUROPE") + TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] + Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1 + Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]] + Inner Join: nation.n_regionkey = region.r_regionkey + Inner Join: supplier.s_nationkey = nation.n_nationkey + Inner Join: partsupp.ps_suppkey = supplier.s_suppkey + TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] + TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] + TableScan: nation projection=[n_nationkey, n_name, n_regionkey] + Filter: region.r_name = Utf8("EUROPE") + TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]"# .to_string(); assert_eq!(actual, expected); diff --git a/datafusion/optimizer/src/reduce_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs similarity index 77% rename from datafusion/optimizer/src/reduce_cross_join.rs rename to datafusion/optimizer/src/eliminate_cross_join.rs index 45230ebb243b..23e80ee542a9 100644 --- a/datafusion/optimizer/src/reduce_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -16,19 +16,19 @@ // under the License. //! Optimizer rule to reduce cross join to inner join if join predicates are available in filters. -use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column, Result}; +use crate::{utils, OptimizerConfig, OptimizerRule}; +use datafusion_common::{Column, DFSchema, DataFusionError, Result}; use datafusion_expr::{ and, expr::BinaryExpr, logical_plan::{CrossJoin, Filter, Join, JoinType, LogicalPlan}, or, utils::can_hash, - utils::from_plan, + Projection, }; use datafusion_expr::{Expr, Operator}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; //use std::collections::HashMap; use datafusion_expr::logical_plan::JoinConstraint; @@ -44,16 +44,92 @@ impl ReduceCrossJoin { } } +/// Attempt to reorder join tp reduce cross joins to inner joins. +/// for queries: +/// 'select ... from a, b where a.x = b.y and b.xx = 100;' +/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' +/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) +/// or (a.x = b.y and b.xx = 200 and a.z=c.z);' +/// For above queries, the join predicate is available in filters and they are moved to +/// join nodes appropriately +/// This fix helps to improve the performance of TPCH Q19. issue#78 +/// impl OptimizerRule for ReduceCrossJoin { fn optimize( &self, plan: &LogicalPlan, _optimizer_config: &mut OptimizerConfig, ) -> Result { - let mut possible_join_keys: Vec<(Column, Column)> = vec![]; - let mut all_join_keys = HashSet::new(); + match plan { + LogicalPlan::Filter(filter) => { + let input = (**filter.input()).clone(); + + let mut possible_join_keys: Vec<(Column, Column)> = vec![]; + let mut all_inputs: Vec = vec![]; + match &input { + LogicalPlan::Join(join) if (join.join_type == JoinType::Inner) => { + flatten_join_inputs( + &input, + &mut possible_join_keys, + &mut all_inputs, + )?; + } + LogicalPlan::CrossJoin(_) => { + flatten_join_inputs( + &input, + &mut possible_join_keys, + &mut all_inputs, + )?; + } + _ => { + return utils::optimize_children(self, plan, _optimizer_config); + } + } + + let predicate = filter.predicate(); + // join keys are handled locally + let mut all_join_keys: HashSet<(Column, Column)> = HashSet::new(); + + extract_possible_join_keys(predicate, &mut possible_join_keys); + + let mut left = all_inputs.remove(0); + while !all_inputs.is_empty() { + left = find_inner_join( + &left, + &mut all_inputs, + &mut possible_join_keys, + &mut all_join_keys, + )?; + } - reduce_cross_join(self, plan, &mut possible_join_keys, &mut all_join_keys) + left = utils::optimize_children(self, &left, _optimizer_config)?; + if plan.schema() != left.schema() { + left = LogicalPlan::Projection(Projection::new_from_schema( + Arc::new(left.clone()), + plan.schema().clone(), + )); + } + + // if there are no join keys then do nothing. + if all_join_keys.is_empty() { + Ok(LogicalPlan::Filter(Filter::try_new( + predicate.clone(), + Arc::new(left), + )?)) + } else { + // remove join expressions from filter + match remove_join_expressions(predicate, &all_join_keys)? { + Some(filter_expr) => Ok(LogicalPlan::Filter(Filter::try_new( + filter_expr, + Arc::new(left), + )?)), + _ => Ok(left), + } + } + } + + _ => utils::optimize_children(self, plan, _optimizer_config), + } } fn name(&self) -> &str { @@ -61,126 +137,108 @@ impl OptimizerRule for ReduceCrossJoin { } } -/// Attempt to reduce cross joins to inner joins. -/// for queries: -/// 'select ... from a, b where a.x = b.y and b.xx = 100;' -/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' -/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) -/// or (a.x = b.y and b.xx = 200 and a.z=c.z);' -/// For above queries, the join predicate is available in filters and they are moved to -/// join nodes appropriately -/// This fix helps to improve the performance of TPCH Q19. issue#78 -/// -fn reduce_cross_join( - _optimizer: &ReduceCrossJoin, +fn flatten_join_inputs( plan: &LogicalPlan, possible_join_keys: &mut Vec<(Column, Column)>, - all_join_keys: &mut HashSet<(Column, Column)>, -) -> Result { - match plan { - LogicalPlan::Filter(filter) => { - let input = filter.input(); - let predicate = filter.predicate(); - // join keys are handled locally - let mut new_possible_join_keys: Vec<(Column, Column)> = vec![]; - let mut new_all_join_keys = HashSet::new(); - - extract_possible_join_keys(predicate, &mut new_possible_join_keys); - - let new_plan = reduce_cross_join( - _optimizer, - input, - &mut new_possible_join_keys, - &mut new_all_join_keys, - )?; - - // if there are no join keys then do nothing. - if new_all_join_keys.is_empty() { - Ok(LogicalPlan::Filter(Filter::try_new( - predicate.clone(), - Arc::new(new_plan), - )?)) - } else { - // remove join expressions from filter - match remove_join_expressions(predicate, &new_all_join_keys)? { - Some(filter_expr) => Ok(LogicalPlan::Filter(Filter::try_new( - filter_expr, - Arc::new(new_plan), - )?)), - _ => Ok(new_plan), - } + all_inputs: &mut Vec, +) -> Result<()> { + let children = match plan { + LogicalPlan::Join(join) => { + for join_keys in join.on.iter() { + possible_join_keys.push(join_keys.clone()); } + let left = &*(join.left); + let right = &*(join.right); + Ok::, DataFusionError>(vec![left, right]) } - LogicalPlan::CrossJoin(cross_join) => { - let left_plan = reduce_cross_join( - _optimizer, - &cross_join.left, - possible_join_keys, - all_join_keys, - )?; - let right_plan = reduce_cross_join( - _optimizer, - &cross_join.right, - possible_join_keys, - all_join_keys, - )?; - // can we find a match? - let left_schema = left_plan.schema(); - let right_schema = right_plan.schema(); - let mut join_keys = vec![]; - - for (l, r) in possible_join_keys { - if left_schema.field_from_column(l).is_ok() - && right_schema.field_from_column(r).is_ok() - && can_hash(left_schema.field_from_column(l).unwrap().data_type()) - { - join_keys.push((l.clone(), r.clone())); - } else if left_schema.field_from_column(r).is_ok() - && right_schema.field_from_column(l).is_ok() - && can_hash(left_schema.field_from_column(r).unwrap().data_type()) - { - join_keys.push((r.clone(), l.clone())); + LogicalPlan::CrossJoin(join) => { + let left = &*(join.left); + let right = &*(join.right); + Ok::, DataFusionError>(vec![left, right]) + } + _ => { + return Err(DataFusionError::Plan( + "flatten_join_inputs just can call join/cross_join".to_string(), + )); + } + }?; + + for child in children.iter() { + match *child { + LogicalPlan::Join(left_join) => { + if left_join.join_type == JoinType::Inner { + flatten_join_inputs(child, possible_join_keys, all_inputs)?; + } else { + all_inputs.push((*child).clone()); } } + LogicalPlan::CrossJoin(_) => { + flatten_join_inputs(child, possible_join_keys, all_inputs)?; + } + _ => all_inputs.push((*child).clone()), + } + } + Ok(()) +} - // if there are no join keys then do nothing. - if join_keys.is_empty() { - Ok(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(left_plan), - right: Arc::new(right_plan), - schema: cross_join.schema.clone(), - })) - } else { - // Keep track of join keys being pushed to Join nodes - all_join_keys.extend(join_keys.clone()); - - Ok(LogicalPlan::Join(Join { - left: Arc::new(left_plan), - right: Arc::new(right_plan), - join_type: JoinType::Inner, - join_constraint: JoinConstraint::On, - on: join_keys, - filter: None, - schema: cross_join.schema.clone(), - null_equals_null: false, - })) +fn find_inner_join( + left: &LogicalPlan, + rights: &mut Vec, + possible_join_keys: &mut Vec<(Column, Column)>, + all_join_keys: &mut HashSet<(Column, Column)>, +) -> Result { + for (i, right) in rights.iter().enumerate() { + let mut join_keys = vec![]; + + for (l, r) in &mut *possible_join_keys { + if left.schema().field_from_column(l).is_ok() + && right.schema().field_from_column(r).is_ok() + && can_hash(left.schema().field_from_column(l).unwrap().data_type()) + { + join_keys.push((l.clone(), r.clone())); + } else if left.schema().field_from_column(r).is_ok() + && right.schema().field_from_column(l).is_ok() + && can_hash(left.schema().field_from_column(r).unwrap().data_type()) + { + join_keys.push((r.clone(), l.clone())); } } - _ => { - let expr = plan.expressions(); - - // apply the optimization to all inputs of the plan - let inputs = plan.inputs(); - let new_inputs = inputs - .iter() - .map(|plan| { - reduce_cross_join(_optimizer, plan, possible_join_keys, all_join_keys) - }) - .collect::>>()?; - - from_plan(plan, &expr, &new_inputs) + + if !join_keys.is_empty() { + all_join_keys.extend(join_keys.clone()); + let right = rights.remove(i); + let join_schema = Arc::new(build_join_schema(left, &right)?); + return Ok(LogicalPlan::Join(Join { + left: Arc::new(left.clone()), + right: Arc::new(right), + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + on: join_keys, + filter: None, + schema: join_schema, + null_equals_null: false, + })); } } + let right = rights.remove(0); + let join_schema = Arc::new(build_join_schema(left, &right)?); + + Ok(LogicalPlan::CrossJoin(CrossJoin { + left: Arc::new(left.clone()), + right: Arc::new(right), + schema: join_schema, + })) +} + +fn build_join_schema(left: &LogicalPlan, right: &LogicalPlan) -> Result { + // build join schema + let mut fields = vec![]; + let mut metadata = HashMap::new(); + fields.extend(left.schema().fields().clone()); + fields.extend(right.schema().fields().clone()); + metadata.extend(left.schema().metadata().clone()); + metadata.extend(right.schema().metadata().clone()); + DFSchema::new_with_metadata(fields, metadata) } fn intersect( @@ -475,6 +533,53 @@ mod tests { Ok(()) } + #[test] + /// ```txt + /// filter: a.id = b.id and a.id = c.id + /// cross_join a (bc) + /// cross_join b c + /// ``` + /// Without reorder, it will be + /// ```txt + /// inner_join a (bc) on a.id = b.id and a.id = c.id + /// cross_join b c + /// ``` + /// Reorder it to be + /// ```txt + /// inner_join (ab)c and a.id = c.id + /// inner_join a b on a.id = b.id + /// ``` + fn reorder_join_to_reduce_cross_join_multi_tables() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + let t3 = test_table_scan_with_name("t3")?; + + // could reduce to inner join + let plan = LogicalPlanBuilder::from(t1) + .cross_join(&t2)? + .cross_join(&t3)? + .filter(binary_expr( + binary_expr(col("t3.a").eq(col("t1.a")), And, col("t3.c").lt(lit(15u32))), + And, + binary_expr(col("t3.a").eq(col("t2.a")), And, col("t3.b").lt(lit(15u32))), + ))? + .build()?; + + let expected = vec![ + "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + ]; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } + #[test] fn reduce_cross_join_multi_tables() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; @@ -849,14 +954,14 @@ mod tests { let expected = vec![ "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -937,13 +1042,13 @@ mod tests { let expected = vec![ "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index e62cbbd73103..467ec3b2408e 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -18,6 +18,7 @@ pub mod common_subexpr_eliminate; pub mod decorrelate_where_exists; pub mod decorrelate_where_in; +pub mod eliminate_cross_join; pub mod eliminate_filter; pub mod eliminate_limit; pub mod filter_null_join_keys; @@ -27,7 +28,6 @@ pub mod limit_push_down; pub mod optimizer; pub mod projection_push_down; pub mod propagate_empty_relation; -pub mod reduce_cross_join; pub mod reduce_outer_join; pub mod scalar_subquery_to_join; pub mod simplify_expressions; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index f09d2ee24cb1..3614c8a4f49f 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -20,6 +20,7 @@ use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_where_exists::DecorrelateWhereExists; use crate::decorrelate_where_in::DecorrelateWhereIn; +use crate::eliminate_cross_join::ReduceCrossJoin; use crate::eliminate_filter::EliminateFilter; use crate::eliminate_limit::EliminateLimit; use crate::filter_null_join_keys::FilterNullJoinKeys; @@ -28,7 +29,6 @@ use crate::inline_table_scan::InlineTableScan; use crate::limit_push_down::LimitPushDown; use crate::projection_push_down::ProjectionPushDown; use crate::propagate_empty_relation::PropagateEmptyRelation; -use crate::reduce_cross_join::ReduceCrossJoin; use crate::reduce_outer_join::ReduceOuterJoin; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index b2932963edd5..fb27ed5edf44 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -267,6 +267,42 @@ fn propagate_empty_relation() { assert_eq!(expected, format!("{:?}", plan)); } +#[test] +fn join_keys_in_subquery_alias() { + let sql = "SELECT * FROM test AS A, ( SELECT col_int32 as key FROM test ) AS B where A.col_int32 = B.key;"; + let plan = test_sql(sql).unwrap(); + let expected = "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\ + \n Inner Join: a.col_int32 = b.key\ + \n Filter: a.col_int32 IS NOT NULL\ + \n SubqueryAlias: a\ + \n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]\ + \n Projection: key, alias=b\ + \n Projection: test.col_int32 AS key\ + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32]"; + assert_eq!(expected, format!("{:?}", plan)); +} + +#[test] +fn join_keys_in_subquery_alias_1() { + let sql = "SELECT * FROM test AS A, ( SELECT test.col_int32 AS key FROM test JOIN test AS C on test.col_int32 = C.col_int32 ) AS B where A.col_int32 = B.key;"; + let plan = test_sql(sql).unwrap(); + let expected = "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\ + \n Inner Join: a.col_int32 = b.key\ + \n Filter: a.col_int32 IS NOT NULL\ + \n SubqueryAlias: a\ + \n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]\ + \n Projection: key, alias=b\ + \n Projection: test.col_int32 AS key\ + \n Inner Join: test.col_int32 = c.col_int32\ + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32]\ + \n Filter: c.col_int32 IS NOT NULL\ + \n SubqueryAlias: c\ + \n TableScan: test projection=[col_int32]"; + assert_eq!(expected, format!("{:?}", plan)); +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 8e3e6d91180d..1e09054c86c8 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -57,12 +57,12 @@ use datafusion_expr::logical_plan::{ use datafusion_expr::logical_plan::{Filter, Subquery}; use datafusion_expr::utils::{ can_hash, check_all_column_from_schema, expand_qualified_wildcard, expand_wildcard, - expr_as_column_expr, find_aggregate_exprs, find_column_exprs, find_window_exprs, - COUNT_STAR_EXPANSION, + expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_column_exprs, + find_window_exprs, COUNT_STAR_EXPANSION, }; use datafusion_expr::Expr::Alias; use datafusion_expr::{ - and, cast, col, lit, AggregateFunction, AggregateUDF, Expr, ExprSchemable, + cast, col, lit, AggregateFunction, AggregateUDF, Expr, ExprSchemable, GetIndexedField, Operator, ScalarUDF, WindowFrame, WindowFrameUnits, }; use datafusion_expr::{ @@ -948,166 +948,52 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { outer_query_schema: Option<&DFSchema>, ctes: &mut HashMap, ) -> Result { + let cross_join_plan = if plans.len() == 1 { + plans[0].clone() + } else { + let mut left = plans[0].clone(); + for right in plans.iter().skip(1) { + left = LogicalPlanBuilder::from(left).cross_join(right)?.build()?; + } + left + }; match selection { Some(predicate_expr) => { - // build join schema let mut fields = vec![]; - let mut metadata = std::collections::HashMap::new(); + let mut metadata = HashMap::new(); for plan in &plans { fields.extend_from_slice(plan.schema().fields()); metadata.extend(plan.schema().metadata().clone()); } + let mut join_schema = DFSchema::new_with_metadata(fields, metadata)?; + let mut all_schemas: Vec = vec![]; + for plan in plans { + for schema in plan.all_schemas() { + all_schemas.push(schema.clone()); + } + } if let Some(outer) = outer_query_schema { + all_schemas.push(Arc::new(outer.clone())); join_schema.merge(outer); } + let x: Vec<&DFSchemaRef> = all_schemas.iter().collect(); let filter_expr = self.sql_to_rex(predicate_expr, &join_schema, ctes)?; + let mut using_columns = HashSet::new(); + expr_to_columns(&filter_expr, &mut using_columns)?; + let filter_expr = normalize_col_with_schemas( + filter_expr, + x.as_slice(), + &[using_columns], + )?; - // look for expressions of the form ` = ` - let mut possible_join_keys = vec![]; - extract_possible_join_keys(&filter_expr, &mut possible_join_keys)?; - - let mut all_join_keys = HashSet::new(); - - let orig_plans = plans.clone(); - let mut plans = plans.into_iter(); - let mut left = plans.next().unwrap(); // have at least one plan - - // List of the plans that have not yet been joined - let mut remaining_plans: Vec> = - plans.into_iter().map(Some).collect(); - - // Take from the list of remaining plans, - loop { - let mut join_keys = vec![]; - - // Search all remaining plans for the next to - // join. Prefer the first one that has a join - // predicate in the predicate lists - let plan_with_idx = - remaining_plans.iter().enumerate().find(|(_idx, plan)| { - // skip plans that have been joined already - let plan = if let Some(plan) = plan { - plan - } else { - return false; - }; - - // can we find a match? - let left_schema = left.schema(); - let right_schema = plan.schema(); - for (l, r) in &possible_join_keys { - if left_schema.field_from_column(l).is_ok() - && right_schema.field_from_column(r).is_ok() - && can_hash( - left_schema - .field_from_column(l) - .unwrap() // the result must be OK - .data_type(), - ) - { - join_keys.push((l.clone(), r.clone())); - } else if left_schema.field_from_column(r).is_ok() - && right_schema.field_from_column(l).is_ok() - && can_hash( - left_schema - .field_from_column(r) - .unwrap() // the result must be OK - .data_type(), - ) - { - join_keys.push((r.clone(), l.clone())); - } - } - // stop if we found join keys - !join_keys.is_empty() - }); - - // If we did not find join keys, either there are - // no more plans, or we can't find any plans that - // can be joined with predicates - if join_keys.is_empty() { - assert!(plan_with_idx.is_none()); - - // pick the first non null plan to join - let plan_with_idx = remaining_plans - .iter() - .enumerate() - .find(|(_idx, plan)| plan.is_some()); - if let Some((idx, _)) = plan_with_idx { - let plan = std::mem::take(&mut remaining_plans[idx]).unwrap(); - left = LogicalPlanBuilder::from(left) - .cross_join(&plan)? - .build()?; - } else { - // no more plans to join - break; - } - } else { - // have a plan - let (idx, _) = plan_with_idx.expect("found plan node"); - let plan = std::mem::take(&mut remaining_plans[idx]).unwrap(); - - let left_keys: Vec = - join_keys.iter().map(|(l, _)| l.clone()).collect(); - let right_keys: Vec = - join_keys.iter().map(|(_, r)| r.clone()).collect(); - let builder = LogicalPlanBuilder::from(left); - left = builder - .join(&plan, JoinType::Inner, (left_keys, right_keys), None)? - .build()?; - } - - all_join_keys.extend(join_keys); - } - - // remove join expressions from filter - match remove_join_expressions(&filter_expr, &all_join_keys)? { - Some(filter_expr) => { - // this logic is adapted from [`LogicalPlanBuilder::filter`] to take - // the query outer schema into account so that joins in subqueries - // can reference outer query fields. - let mut all_schemas: Vec = vec![]; - for plan in orig_plans { - for schema in plan.all_schemas() { - all_schemas.push(schema.clone()); - } - } - if let Some(outer_query_schema) = outer_query_schema { - all_schemas.push(Arc::new(outer_query_schema.clone())); - } - let mut join_columns = HashSet::new(); - for (l, r) in &all_join_keys { - join_columns.insert(l.clone()); - join_columns.insert(r.clone()); - } - let x: Vec<&DFSchemaRef> = all_schemas.iter().collect(); - let filter_expr = normalize_col_with_schemas( - filter_expr, - x.as_slice(), - &[join_columns], - )?; - Ok(LogicalPlan::Filter(Filter::try_new( - filter_expr, - Arc::new(left), - )?)) - } - _ => Ok(left), - } - } - None => { - if plans.len() == 1 { - Ok(plans[0].clone()) - } else { - let mut left = plans[0].clone(); - for right in plans.iter().skip(1) { - left = - LogicalPlanBuilder::from(left).cross_join(right)?.build()?; - } - Ok(left) - } + Ok(LogicalPlan::Filter(Filter::try_new( + filter_expr, + Arc::new(cross_join_plan), + )?)) } + None => Ok(cross_join_plan), } } @@ -2707,7 +2593,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | Value::Null | Value::Placeholder(_) => { return Err(DataFusionError::Plan(format!( - "Unspported Value {}", + "Unsupported Value {}", value[0] ))) } @@ -2718,14 +2604,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { UnaryOperator::Minus => format!("-{}", expr), _ => { return Err(DataFusionError::Plan(format!( - "Unspported Value {}", + "Unsupported Value {}", value[0] ))) } }, _ => { return Err(DataFusionError::Plan(format!( - "Unspported Value {}", + "Unsupported Value {}", value[0] ))) } @@ -3054,41 +2940,6 @@ pub fn object_name_to_qualifier(sql_table_name: &ObjectName) -> String { .join(" AND ") } -/// Remove join expressions from a filter expression -fn remove_join_expressions( - expr: &Expr, - join_columns: &HashSet<(Column, Column)>, -) -> Result> { - match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => { - if join_columns.contains(&(l.clone(), r.clone())) - || join_columns.contains(&(r.clone(), l.clone())) - { - Ok(None) - } else { - Ok(Some(expr.clone())) - } - } - _ => Ok(Some(expr.clone())), - }, - Operator::And => { - let l = remove_join_expressions(left, join_columns)?; - let r = remove_join_expressions(right, join_columns)?; - match (l, r) { - (Some(ll), Some(rr)) => Ok(Some(and(ll, rr))), - (Some(ll), _) => Ok(Some(ll)), - (_, Some(rr)) => Ok(Some(rr)), - _ => Ok(None), - } - } - _ => Ok(Some(expr.clone())), - }, - _ => Ok(Some(expr.clone())), - } -} - /// Extracts equijoin ON condition be a single Eq or multiple conjunctive Eqs /// Filters matching this pattern are added to `accum` /// Filters that don't match this pattern are added to `accum_filter` @@ -3196,30 +3047,6 @@ fn extract_join_keys( Ok(()) } -/// Extract join keys from a WHERE clause -fn extract_possible_join_keys( - expr: &Expr, - accum: &mut Vec<(Column, Column)>, -) -> Result<()> { - match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => { - accum.push((l.clone(), r.clone())); - Ok(()) - } - _ => Ok(()), - }, - Operator::And => { - extract_possible_join_keys(left, accum)?; - extract_possible_join_keys(right, accum) - } - _ => Ok(()), - }, - _ => Ok(()), - } -} - /// Wrap projection for a plan, if the join keys contains normal expression. fn wrap_projection_for_join_if_necessary( join_keys: &[Expr], @@ -5486,18 +5313,6 @@ mod tests { quick_test(sql, expected); } - #[test] - fn cross_join_to_inner_join() { - let sql = "select person.id from person, orders, lineitem where person.id = lineitem.l_item_id and orders.o_item_id = lineitem.l_description;"; - let expected = "Projection: person.id\ - \n Inner Join: lineitem.l_description = orders.o_item_id\ - \n Inner Join: person.id = lineitem.l_item_id\ - \n TableScan: person\ - \n TableScan: lineitem\ - \n TableScan: orders"; - quick_test(sql, expected); - } - #[test] fn cross_join_not_to_inner_join() { let sql = "select person.id from person, orders, lineitem where person.id = person.age;"; @@ -5581,15 +5396,15 @@ mod tests { AND person.state = p.state)"; let expected = "Projection: person.id\ - \n Filter: EXISTS ()\ + \n Filter: person.id = p.id AND EXISTS ()\ \n Subquery:\ \n Projection: person.first_name\ - \n Filter: person.last_name = p.last_name AND person.state = p.state\ - \n Inner Join: person.id = p2.id\ + \n Filter: person.id = p2.id AND person.last_name = p.last_name AND person.state = p.state\ + \n CrossJoin:\ \n TableScan: person\ \n SubqueryAlias: p2\ \n TableScan: person\ - \n Inner Join: person.id = p.id\ + \n CrossJoin:\ \n TableScan: person\ \n SubqueryAlias: p\ \n TableScan: person"; @@ -5675,8 +5490,8 @@ mod tests { \n Subquery:\ \n Projection: COUNT(UInt8(1))\ \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n Filter: j2.j2_id = j1.j1_id\ - \n Inner Join: j1.j1_id = j3.j3_id\ + \n Filter: j2.j2_id = j1.j1_id AND j1.j1_id = j3.j3_id\ + \n CrossJoin:\ \n TableScan: j1\ \n TableScan: j3\ \n CrossJoin:\ @@ -6090,8 +5905,8 @@ mod tests { #[test] fn test_select_join_key_inner_join() { let sql = "SELECT orders.customer_id * 2, person.id + 10 - FROM person - INNER JOIN orders + FROM person + INNER JOIN orders ON orders.customer_id * 2 = person.id + 10"; let expected = "Projection: orders.customer_id * Int64(2), person.id + Int64(10)\ @@ -6107,9 +5922,9 @@ mod tests { #[test] fn test_non_projetion_after_inner_join() { // There's no need to add projection for left and right, so does adding projection after join. - let sql = "SELECT person.id, person.age - FROM person - INNER JOIN orders + let sql = "SELECT person.id, person.age + FROM person + INNER JOIN orders ON orders.customer_id = person.id"; let expected = "Projection: person.id, person.age\ @@ -6122,9 +5937,9 @@ mod tests { #[test] fn test_duplicated_left_join_key_inner_join() { // person.id * 2 happen twice in left side. - let sql = "SELECT person.id, person.age - FROM person - INNER JOIN orders + let sql = "SELECT person.id, person.age + FROM person + INNER JOIN orders ON person.id * 2 = orders.customer_id + 10 and person.id * 2 = orders.order_id"; let expected = "Projection: person.id, person.age\ @@ -6140,9 +5955,9 @@ mod tests { #[test] fn test_duplicated_right_join_key_inner_join() { // orders.customer_id + 10 happen twice in right side. - let sql = "SELECT person.id, person.age - FROM person - INNER JOIN orders + let sql = "SELECT person.id, person.age + FROM person + INNER JOIN orders ON person.id * 2 = orders.customer_id + 10 and person.id = orders.customer_id + 10"; let expected = "Projection: person.id, person.age\