From 735803540f3c8d8e7a8e4a5c4815d1dd0b170b6d Mon Sep 17 00:00:00 2001 From: kmitchener Date: Mon, 29 Aug 2022 14:09:23 -0400 Subject: [PATCH 1/4] initial commit for #3266 --- .../src/decorrelate_scalar_subquery.rs | 68 +++++++++++++------ datafusion/optimizer/src/utils.rs | 4 +- 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs index 561757dc8745..d4d0fc25ba85 100644 --- a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs @@ -56,8 +56,8 @@ impl DecorrelateScalarSubquery { for it in filters.iter() { match it { Expr::BinaryExpr { left, op, right } => { - let l_query = Subquery::try_from_expr(left); - let r_query = Subquery::try_from_expr(right); + let l_query = Subquery::try_from_expr(&left); + let r_query = Subquery::try_from_expr(&right); if l_query.is_err() && r_query.is_err() { others.push((*it).clone()); continue; @@ -101,16 +101,17 @@ impl OptimizerRule for DecorrelateScalarSubquery { let (subqueries, other_exprs) = self.extract_subquery_exprs(predicate, optimizer_config)?; - let optimized_plan = LogicalPlan::Filter(Filter { - predicate: predicate.clone(), - input: Arc::new(optimized_input), - }); + if subqueries.is_empty() { // regular filter, no subquery exists clause here + let optimized_plan = LogicalPlan::Filter(Filter { + predicate: predicate.clone(), + input: Arc::new(optimized_input), + }); return Ok(optimized_plan); } - // iterate through all exists clauses in predicate, turning each into a join + // iterate through all subqueries in predicate, turning each into a join let mut cur_input = (**input).clone(); for subquery in subqueries { cur_input = optimize_scalar( @@ -136,22 +137,39 @@ impl OptimizerRule for DecorrelateScalarSubquery { /// Takes a query like: /// -/// ```select id from customers where balance > +/// ```text +/// select id from customers where balance > /// (select avg(total) from orders where orders.c_id = customers.id) /// ``` /// /// and optimizes it into: /// -/// ```select c.id from customers c +/// ```text +/// select c.id from customers c /// inner join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id -/// where c.balance > o.val``` +/// where c.balance > o.val +/// ``` +/// +/// Or a query like: +/// +/// ```text +/// select id from customers where balance > +/// (select avg(total) from orders) +/// ``` +/// +/// and optimizes it into: +/// +/// ```text +/// select c.id from customers c +/// cross join (select avg(total) as val from orders) a +/// where c.balance > a.val +/// ``` /// /// # Arguments /// -/// * `subqry` - The subquery portion of the `where exists` (select * from orders) -/// * `negated` - True if the subquery is a `where not exists` +/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders) /// * `filter_input` - The non-subquery portion (from customers) -/// * `other_filter_exprs` - Any additional parts to the `where` expression (and c.x = y) +/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y) /// * `optimizer_config` - Used to generate unique subquery aliases fn optimize_scalar( query_info: &SubqueryInfo, @@ -173,20 +191,28 @@ fn optimize_scalar( .map_err(|e| context!("Exactly one input is expected. Is this a join?", e))?; let aggr = Aggregate::try_from_plan(sub_input) .map_err(|e| context!("scalar subqueries must aggregate a value", e))?; - let filter = Filter::try_from_plan(&aggr.input).map_err(|e| { - context!("scalar subqueries must have a filter to be correlated", e) - })?; + let filter = Filter::try_from_plan(&aggr.input).ok(); + + // if there were filters, we use that logical plan, otherwise the plan from the aggregate + let input: &LogicalPlan; + if filter.is_some() { + input = &filter.unwrap().input; + } else { + input = &aggr.input; + }; - // split into filters + // if there were filters, split and capture them let mut subqry_filter_exprs = vec![]; - split_conjunction(&filter.predicate, &mut subqry_filter_exprs); + if filter.is_some() { + split_conjunction(&filter.unwrap().predicate, &mut subqry_filter_exprs); + } verify_not_disjunction(&subqry_filter_exprs)?; // Grab column names to join on let (col_exprs, other_subqry_exprs) = - find_join_exprs(subqry_filter_exprs, filter.input.schema())?; + find_join_exprs(subqry_filter_exprs, *&input.schema())?; let (outer_cols, subqry_cols, join_filters) = - exprs_to_join_cols(&col_exprs, filter.input.schema(), false)?; + exprs_to_join_cols(&col_exprs, *&input.schema(), false)?; if join_filters.is_some() { plan_err!("only joins on column equality are presently supported")?; } @@ -199,7 +225,7 @@ fn optimize_scalar( .collect(); // build subquery side of join - the thing the subquery was querying - let mut subqry_plan = LogicalPlanBuilder::from((*filter.input).clone()); + let mut subqry_plan = LogicalPlanBuilder::from((*input).clone()); if let Some(expr) = combine_filters(&other_subqry_exprs) { subqry_plan = subqry_plan.filter(expr)? // if the subquery had additional expressions, restore them } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 41c75d689f5d..d962dd7b45b9 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -125,7 +125,7 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan { /// # Arguments /// /// * `exprs` - List of expressions that may or may not be joins -/// * `fields` - HashSet of fully qualified (table.col) fields in subquery schema +/// * `schema` - HashSet of fully qualified (table.col) fields in subquery schema /// /// # Return value /// @@ -191,7 +191,7 @@ pub fn find_join_exprs( /// # Arguments /// /// * `exprs` - List of expressions that correlate a subquery to an outer scope -/// * `fields` - HashSet of fully qualified (table.col) fields in subquery schema +/// * `schema` - subquery schema /// * `include_negated` - true if `NotEq` counts as a join operator /// /// # Return value From 60607be10ceb9a2d9afd322774172dae4426b1e8 Mon Sep 17 00:00:00 2001 From: kmitchener Date: Mon, 29 Aug 2022 14:33:35 -0400 Subject: [PATCH 2/4] added test --- .../src/decorrelate_scalar_subquery.rs | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs index d4d0fc25ba85..ba811aaa1e2e 100644 --- a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs @@ -728,4 +728,34 @@ mod tests { assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); Ok(()) } + + /// Test for non-correlated scalar subquery with no filters + #[test] + fn scalar_subquery_non_correlated_no_filters() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter( + col("customer.c_custkey") + .eq(scalar_subquery(sq)), + )? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: #customer.c_custkey = #__sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N] + CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #MAX(orders.o_custkey) AS __value, alias=__sq_1 [__value:Int64;N] + Aggregate: groupBy=[[]], aggr=[[MAX(#orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } } From ab0f9ae4bd65299220e7702cb27b319afe52c22d Mon Sep 17 00:00:00 2001 From: kmitchener Date: Mon, 29 Aug 2022 14:40:57 -0400 Subject: [PATCH 3/4] lint strikes again --- datafusion/optimizer/src/decorrelate_scalar_subquery.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs index ba811aaa1e2e..d08d99f4463d 100644 --- a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs @@ -740,10 +740,7 @@ mod tests { ); let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) - .filter( - col("customer.c_custkey") - .eq(scalar_subquery(sq)), - )? + .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))? .project(vec![col("customer.c_custkey")])? .build()?; From f395f776242d145a11f19dea393ca27c305bd3eb Mon Sep 17 00:00:00 2001 From: kmitchener Date: Mon, 29 Aug 2022 15:05:55 -0400 Subject: [PATCH 4/4] clippy --- .../src/decorrelate_scalar_subquery.rs | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs index d08d99f4463d..1d6e5d5338cb 100644 --- a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs @@ -56,8 +56,8 @@ impl DecorrelateScalarSubquery { for it in filters.iter() { match it { Expr::BinaryExpr { left, op, right } => { - let l_query = Subquery::try_from_expr(&left); - let r_query = Subquery::try_from_expr(&right); + let l_query = Subquery::try_from_expr(left); + let r_query = Subquery::try_from_expr(right); if l_query.is_err() && r_query.is_err() { others.push((*it).clone()); continue; @@ -194,25 +194,24 @@ fn optimize_scalar( let filter = Filter::try_from_plan(&aggr.input).ok(); // if there were filters, we use that logical plan, otherwise the plan from the aggregate - let input: &LogicalPlan; - if filter.is_some() { - input = &filter.unwrap().input; + let input = if let Some(filter) = filter { + &filter.input } else { - input = &aggr.input; + &aggr.input }; // if there were filters, split and capture them let mut subqry_filter_exprs = vec![]; - if filter.is_some() { - split_conjunction(&filter.unwrap().predicate, &mut subqry_filter_exprs); + if let Some(filter) = filter { + split_conjunction(&filter.predicate, &mut subqry_filter_exprs); } verify_not_disjunction(&subqry_filter_exprs)?; // Grab column names to join on let (col_exprs, other_subqry_exprs) = - find_join_exprs(subqry_filter_exprs, *&input.schema())?; + find_join_exprs(subqry_filter_exprs, input.schema())?; let (outer_cols, subqry_cols, join_filters) = - exprs_to_join_cols(&col_exprs, *&input.schema(), false)?; + exprs_to_join_cols(&col_exprs, input.schema(), false)?; if join_filters.is_some() { plan_err!("only joins on column equality are presently supported")?; } @@ -225,7 +224,7 @@ fn optimize_scalar( .collect(); // build subquery side of join - the thing the subquery was querying - let mut subqry_plan = LogicalPlanBuilder::from((*input).clone()); + let mut subqry_plan = LogicalPlanBuilder::from((**input).clone()); if let Some(expr) = combine_filters(&other_subqry_exprs) { subqry_plan = subqry_plan.filter(expr)? // if the subquery had additional expressions, restore them }