diff --git a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs index 561757dc8745..1d6e5d5338cb 100644 --- a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs @@ -101,16 +101,17 @@ impl OptimizerRule for DecorrelateScalarSubquery { let (subqueries, other_exprs) = self.extract_subquery_exprs(predicate, optimizer_config)?; - let optimized_plan = LogicalPlan::Filter(Filter { - predicate: predicate.clone(), - input: Arc::new(optimized_input), - }); + if subqueries.is_empty() { // regular filter, no subquery exists clause here + let optimized_plan = LogicalPlan::Filter(Filter { + predicate: predicate.clone(), + input: Arc::new(optimized_input), + }); return Ok(optimized_plan); } - // iterate through all exists clauses in predicate, turning each into a join + // iterate through all subqueries in predicate, turning each into a join let mut cur_input = (**input).clone(); for subquery in subqueries { cur_input = optimize_scalar( @@ -136,22 +137,39 @@ impl OptimizerRule for DecorrelateScalarSubquery { /// Takes a query like: /// -/// ```select id from customers where balance > +/// ```text +/// select id from customers where balance > /// (select avg(total) from orders where orders.c_id = customers.id) /// ``` /// /// and optimizes it into: /// -/// ```select c.id from customers c +/// ```text +/// select c.id from customers c /// inner join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id -/// where c.balance > o.val``` +/// where c.balance > o.val +/// ``` +/// +/// Or a query like: +/// +/// ```text +/// select id from customers where balance > +/// (select avg(total) from orders) +/// ``` +/// +/// and optimizes it into: +/// +/// ```text +/// select c.id from customers c +/// cross join (select avg(total) as val from orders) a +/// where c.balance > a.val +/// ``` /// /// # Arguments /// -/// * `subqry` - The 
subquery portion of the `where exists` (select * from orders) -/// * `negated` - True if the subquery is a `where not exists` +/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders) /// * `filter_input` - The non-subquery portion (from customers) -/// * `other_filter_exprs` - Any additional parts to the `where` expression (and c.x = y) +/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y) /// * `optimizer_config` - Used to generate unique subquery aliases fn optimize_scalar( query_info: &SubqueryInfo, @@ -173,20 +191,27 @@ fn optimize_scalar( .map_err(|e| context!("Exactly one input is expected. Is this a join?", e))?; let aggr = Aggregate::try_from_plan(sub_input) .map_err(|e| context!("scalar subqueries must aggregate a value", e))?; - let filter = Filter::try_from_plan(&aggr.input).map_err(|e| { - context!("scalar subqueries must have a filter to be correlated", e) - })?; + let filter = Filter::try_from_plan(&aggr.input).ok(); - // split into filters + // if there were filters, we use that logical plan, otherwise the plan from the aggregate + let input = if let Some(filter) = filter { + &filter.input + } else { + &aggr.input + }; + + // if there were filters, split and capture them let mut subqry_filter_exprs = vec![]; - split_conjunction(&filter.predicate, &mut subqry_filter_exprs); + if let Some(filter) = filter { + split_conjunction(&filter.predicate, &mut subqry_filter_exprs); + } verify_not_disjunction(&subqry_filter_exprs)?; // Grab column names to join on let (col_exprs, other_subqry_exprs) = - find_join_exprs(subqry_filter_exprs, filter.input.schema())?; + find_join_exprs(subqry_filter_exprs, input.schema())?; let (outer_cols, subqry_cols, join_filters) = - exprs_to_join_cols(&col_exprs, filter.input.schema(), false)?; + exprs_to_join_cols(&col_exprs, input.schema(), false)?; if join_filters.is_some() { plan_err!("only joins on column equality are presently supported")?; } @@ -199,7 +224,7 
@@ fn optimize_scalar( .collect(); // build subquery side of join - the thing the subquery was querying - let mut subqry_plan = LogicalPlanBuilder::from((*filter.input).clone()); + let mut subqry_plan = LogicalPlanBuilder::from((**input).clone()); if let Some(expr) = combine_filters(&other_subqry_exprs) { subqry_plan = subqry_plan.filter(expr)? // if the subquery had additional expressions, restore them } @@ -702,4 +727,31 @@ mod tests { assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); Ok(()) } + + /// Test for non-correlated scalar subquery with no filters + #[test] + fn scalar_subquery_non_correlated_no_filters() -> Result<()> { + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))? + .project(vec![col("customer.c_custkey")])? 
+ .build()?; + + let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64] + Filter: #customer.c_custkey = #__sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N] + CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + Projection: #MAX(orders.o_custkey) AS __value, alias=__sq_1 [__value:Int64;N] + Aggregate: groupBy=[[]], aggr=[[MAX(#orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + + assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected); + Ok(()) + } } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 41c75d689f5d..d962dd7b45b9 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -125,7 +125,7 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan { /// # Arguments /// /// * `exprs` - List of expressions that may or may not be joins -/// * `fields` - HashSet of fully qualified (table.col) fields in subquery schema +/// * `schema` - subquery schema /// /// # Return value /// @@ -191,7 +191,7 @@ pub fn find_join_exprs( /// # Arguments /// /// * `exprs` - List of expressions that correlate a subquery to an outer scope -/// * `fields` - HashSet of fully qualified (table.col) fields in subquery schema +/// * `schema` - subquery schema /// * `include_negated` - true if `NotEq` counts as a join operator /// /// # Return value