Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 71 additions & 19 deletions datafusion/optimizer/src/decorrelate_scalar_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,17 @@ impl OptimizerRule for DecorrelateScalarSubquery {

let (subqueries, other_exprs) =
self.extract_subquery_exprs(predicate, optimizer_config)?;
let optimized_plan = LogicalPlan::Filter(Filter {
predicate: predicate.clone(),
input: Arc::new(optimized_input),
});

if subqueries.is_empty() {
// regular filter, no subquery exists clause here
let optimized_plan = LogicalPlan::Filter(Filter {
predicate: predicate.clone(),
input: Arc::new(optimized_input),
});
return Ok(optimized_plan);
}

// iterate through all exists clauses in predicate, turning each into a join
// iterate through all subqueries in predicate, turning each into a join
let mut cur_input = (**input).clone();
for subquery in subqueries {
cur_input = optimize_scalar(
Expand All @@ -136,22 +137,39 @@ impl OptimizerRule for DecorrelateScalarSubquery {

/// Takes a query like:
///
/// ```select id from customers where balance >
/// ```text
/// select id from customers where balance >
/// (select avg(total) from orders where orders.c_id = customers.id)
/// ```
///
/// and optimizes it into:
///
/// ```select c.id from customers c
/// ```text
/// select c.id from customers c
/// inner join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id
/// where c.balance > o.val```
/// where c.balance > o.val
/// ```
///
/// Or a query like:
///
/// ```text
/// select id from customers where balance >
/// (select avg(total) from orders)
/// ```
///
/// and optimizes it into:
///
/// ```text
/// select c.id from customers c
/// cross join (select avg(total) as val from orders) a
/// where c.balance > a.val
/// ```
///
/// # Arguments
///
/// * `subqry` - The subquery portion of the `where exists` (select * from orders)
/// * `negated` - True if the subquery is a `where not exists`
/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders)
/// * `filter_input` - The non-subquery portion (from customers)
/// * `other_filter_exprs` - Any additional parts to the `where` expression (and c.x = y)
/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y)
/// * `optimizer_config` - Used to generate unique subquery aliases
fn optimize_scalar(
query_info: &SubqueryInfo,
Expand All @@ -173,20 +191,27 @@ fn optimize_scalar(
.map_err(|e| context!("Exactly one input is expected. Is this a join?", e))?;
let aggr = Aggregate::try_from_plan(sub_input)
.map_err(|e| context!("scalar subqueries must aggregate a value", e))?;
let filter = Filter::try_from_plan(&aggr.input).map_err(|e| {
context!("scalar subqueries must have a filter to be correlated", e)
})?;
let filter = Filter::try_from_plan(&aggr.input).ok();

// split into filters
// if there were filters, we use that logical plan, otherwise the plan from the aggregate
let input = if let Some(filter) = filter {
&filter.input
} else {
&aggr.input
};

// if there were filters, split and capture them
let mut subqry_filter_exprs = vec![];
split_conjunction(&filter.predicate, &mut subqry_filter_exprs);
if let Some(filter) = filter {
split_conjunction(&filter.predicate, &mut subqry_filter_exprs);
}
verify_not_disjunction(&subqry_filter_exprs)?;

// Grab column names to join on
let (col_exprs, other_subqry_exprs) =
find_join_exprs(subqry_filter_exprs, filter.input.schema())?;
find_join_exprs(subqry_filter_exprs, input.schema())?;
let (outer_cols, subqry_cols, join_filters) =
exprs_to_join_cols(&col_exprs, filter.input.schema(), false)?;
exprs_to_join_cols(&col_exprs, input.schema(), false)?;
if join_filters.is_some() {
plan_err!("only joins on column equality are presently supported")?;
}
Expand All @@ -199,7 +224,7 @@ fn optimize_scalar(
.collect();

// build subquery side of join - the thing the subquery was querying
let mut subqry_plan = LogicalPlanBuilder::from((*filter.input).clone());
let mut subqry_plan = LogicalPlanBuilder::from((**input).clone());
if let Some(expr) = combine_filters(&other_subqry_exprs) {
subqry_plan = subqry_plan.filter(expr)? // if the subquery had additional expressions, restore them
}
Expand Down Expand Up @@ -702,4 +727,31 @@ mod tests {
assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
Ok(())
}

/// Test for non-correlated scalar subquery with no filters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we rename this if it also does uncorrelated subqueries now? Maybe scalar_subquery_to_join?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean rename the whole rule? Probably so .. good idea. You want to do that as part of this issue?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to, but we probably should. Maybe we can just doc it for now with a TODO and address it in a future PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It touches a bunch of code, but yes, your suggested name is much more clear. I'll do a separate PR for it this week if this gets merged in.

#[test]
fn scalar_subquery_non_correlated_no_filters() -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like 1 branch was added, and 1 test. LGTM!

let sq = Arc::new(
LogicalPlanBuilder::from(scan_tpch_table("orders"))
.aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
.project(vec![max(col("orders.o_custkey"))])?
.build()?,
);

let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
.filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
.project(vec![col("customer.c_custkey")])?
.build()?;

let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
Filter: #customer.c_custkey = #__sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]
CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]
Projection: #MAX(orders.o_custkey) AS __value, alias=__sq_1 [__value:Int64;N]
Aggregate: groupBy=[[]], aggr=[[MAX(#orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;

assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
Ok(())
}
}
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan {
/// # Arguments
///
/// * `exprs` - List of expressions that may or may not be joins
/// * `fields` - HashSet of fully qualified (table.col) fields in subquery schema
/// * `schema` - HashSet of fully qualified (table.col) fields in subquery schema
///
/// # Return value
///
Expand Down Expand Up @@ -191,7 +191,7 @@ pub fn find_join_exprs(
/// # Arguments
///
/// * `exprs` - List of expressions that correlate a subquery to an outer scope
/// * `fields` - HashSet of fully qualified (table.col) fields in subquery schema
/// * `schema` - subquery schema
/// * `include_negated` - true if `NotEq` counts as a join operator
///
/// # Return value
Expand Down