diff --git a/Cargo.lock b/Cargo.lock
index e0f811f16f1b..8a2c1faaa149 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2456,6 +2456,7 @@ dependencies = [
  "datafusion-common",
  "datafusion-expr",
  "datafusion-expr-common",
+ "datafusion-functions",
  "datafusion-functions-aggregate",
  "datafusion-functions-window",
  "datafusion-functions-window-common",
diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml
index 15d3261ca513..7163d9566c01 100644
--- a/datafusion/optimizer/Cargo.toml
+++ b/datafusion/optimizer/Cargo.toml
@@ -61,6 +61,7 @@ regex-syntax = "0.8.6"
 async-trait = { workspace = true }
 criterion = { workspace = true }
 ctor = { workspace = true }
+datafusion-functions = { workspace = true }
 datafusion-functions-aggregate = { workspace = true }
 datafusion-functions-window = { workspace = true }
 datafusion-functions-window-common = { workspace = true }
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index a38cd7a75bc1..15456d2c1ea4 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -31,7 +31,7 @@ use datafusion_common::{
     assert_eq_or_internal_err, assert_or_internal_err, internal_err, plan_err,
     qualified_name, Column, DFSchema, DataFusionError, Result,
 };
-use datafusion_expr::expr::WindowFunction;
+use datafusion_expr::expr::{Between, InList, ScalarFunction, WindowFunction};
 use datafusion_expr::expr_rewriter::replace_col;
 use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
 use datafusion_expr::utils::{
@@ -421,3 +421,248 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Ex
+/// Pushes filters on `coalesce(left_key, right_key)` below a join by rewriting
+/// them into one filter per join input.
+///
+/// A predicate term that tests `coalesce` over the two columns of an equi-join
+/// key pair is duplicated: once with the left key column substituted for the
+/// `coalesce` call and once with the right key column substituted.
+struct PushDownCoalesceFilterHelper {
+    /// Equi-join key pairs `(left, right)` in which both sides are plain columns
+    join_keys: Vec<(Column, Column)>,
+    /// Rewritten terms for the left join input
+    left_filters: Vec<Expr>,
+    /// Rewritten terms for the right join input
+    right_filters: Vec<Expr>,
+    /// Terms that were not rewritten
+    remaining_filters: Vec<Expr>,
+}
+
+impl PushDownCoalesceFilterHelper {
+    /// Creates a helper from the join's `on` pairs, keeping only the pairs in
+    /// which both expressions are plain column references.
+    fn new(join_keys: &[(Expr, Expr)]) -> Self {
+        let join_keys = join_keys
+            .iter()
+            .filter_map(|(lhs, rhs)| {
+                Some((lhs.try_as_col()?.clone(), rhs.try_as_col()?.clone()))
+            })
+            .collect();
+        Self {
+            join_keys,
+            left_filters: Vec::new(),
+            right_filters: Vec::new(),
+            remaining_filters: Vec::new(),
+        }
+    }
+
+    /// Returns true if `expr` can be copied into a filter on either join input:
+    /// it references no columns and is not volatile.
+    fn is_side_independent(expr: &Expr) -> bool {
+        !expr.any_column_refs() && !expr.is_volatile()
+    }
+
+    /// Records one filter per join input, built by applying `build_filter` to
+    /// the left key column and then to the right key column.
+    fn push_columns<F: FnMut(Expr) -> Expr>(
+        &mut self,
+        columns: (Column, Column),
+        mut build_filter: F,
+    ) {
+        self.left_filters
+            .push(build_filter(Expr::Column(columns.0)));
+        self.right_filters
+            .push(build_filter(Expr::Column(columns.1)));
+    }
+
+    /// If `expr` is a two-argument `coalesce` over both columns of one join key
+    /// pair (in either order), returns that pair as `(left, right)`.
+    fn extract_join_columns(&self, expr: &Expr) -> Option<(Column, Column)> {
+        if let Expr::ScalarFunction(ScalarFunction { func, args }) = expr {
+            if func.name() != "coalesce" {
+                return None;
+            }
+            if let [Expr::Column(lhs), Expr::Column(rhs)] = args.as_slice() {
+                for (join_lhs, join_rhs) in &self.join_keys {
+                    if join_lhs == lhs && join_rhs == rhs {
+                        return Some((lhs.clone(), rhs.clone()));
+                    }
+                    if join_lhs == rhs && join_rhs == lhs {
+                        return Some((rhs.clone(), lhs.clone()));
+                    }
+                }
+            }
+        }
+        None
+    }
+
+    /// Rewrites `term` into one filter per join input when it tests a
+    /// `coalesce` of join keys; otherwise adds it to `remaining_filters`.
+    ///
+    /// The other operands of the term are copied unchanged into both filters,
+    /// so the term is only rewritten when they are side independent. Without
+    /// that check `coalesce(t1.a, t2.a) = t1.b` would yield a filter on `t2`
+    /// that references a column of `t1`.
+    fn push_term(&mut self, term: &Expr) {
+        match term {
+            Expr::BinaryExpr(BinaryExpr { left, op, right })
+                if op.supports_propagation() =>
+            {
+                if Self::is_side_independent(right) {
+                    if let Some(columns) = self.extract_join_columns(left) {
+                        return self.push_columns(columns, |replacement| {
+                            Expr::BinaryExpr(BinaryExpr {
+                                left: Box::new(replacement),
+                                op: *op,
+                                right: right.clone(),
+                            })
+                        });
+                    }
+                }
+                if Self::is_side_independent(left) {
+                    if let Some(columns) = self.extract_join_columns(right) {
+                        return self.push_columns(columns, |replacement| {
+                            Expr::BinaryExpr(BinaryExpr {
+                                left: left.clone(),
+                                op: *op,
+                                right: Box::new(replacement),
+                            })
+                        });
+                    }
+                }
+            }
+            Expr::IsNull(expr) => {
+                if let Some(columns) = self.extract_join_columns(expr) {
+                    return self.push_columns(columns, |replacement| {
+                        Expr::IsNull(Box::new(replacement))
+                    });
+                }
+            }
+            Expr::IsNotNull(expr) => {
+                if let Some(columns) = self.extract_join_columns(expr) {
+                    return self.push_columns(columns, |replacement| {
+                        Expr::IsNotNull(Box::new(replacement))
+                    });
+                }
+            }
+            Expr::IsTrue(expr) => {
+                if let Some(columns) = self.extract_join_columns(expr) {
+                    return self.push_columns(columns, |replacement| {
+                        Expr::IsTrue(Box::new(replacement))
+                    });
+                }
+            }
+            Expr::IsFalse(expr) => {
+                if let Some(columns) = self.extract_join_columns(expr) {
+                    return self.push_columns(columns, |replacement| {
+                        Expr::IsFalse(Box::new(replacement))
+                    });
+                }
+            }
+            Expr::IsUnknown(expr) => {
+                if let Some(columns) = self.extract_join_columns(expr) {
+                    return self.push_columns(columns, |replacement| {
+                        Expr::IsUnknown(Box::new(replacement))
+                    });
+                }
+            }
+            Expr::IsNotTrue(expr) => {
+                if let Some(columns) = self.extract_join_columns(expr) {
+                    return self.push_columns(columns, |replacement| {
+                        Expr::IsNotTrue(Box::new(replacement))
+                    });
+                }
+            }
+            Expr::IsNotFalse(expr) => {
+                if let Some(columns) = self.extract_join_columns(expr) {
+                    return self.push_columns(columns, |replacement| {
+                        Expr::IsNotFalse(Box::new(replacement))
+                    });
+                }
+            }
+            Expr::IsNotUnknown(expr) => {
+                if let Some(columns) = self.extract_join_columns(expr) {
+                    return self.push_columns(columns, |replacement| {
+                        Expr::IsNotUnknown(Box::new(replacement))
+                    });
+                }
+            }
+            Expr::Between(between)
+                if Self::is_side_independent(&between.low)
+                    && Self::is_side_independent(&between.high) =>
+            {
+                if let Some(columns) = self.extract_join_columns(&between.expr) {
+                    return self.push_columns(columns, |replacement| {
+                        Expr::Between(Between {
+                            expr: Box::new(replacement),
+                            negated: between.negated,
+                            low: between.low.clone(),
+                            high: between.high.clone(),
+                        })
+                    });
+                }
+            }
+            Expr::InList(in_list)
+                if in_list.list.iter().all(Self::is_side_independent) =>
+            {
+                if let Some(columns) = self.extract_join_columns(&in_list.expr) {
+                    return self.push_columns(columns, |replacement| {
+                        Expr::InList(InList {
+                            expr: Box::new(replacement),
+                            list: in_list.list.clone(),
+                            negated: in_list.negated,
+                        })
+                    });
+                }
+            }
+            _ => {}
+        }
+        self.remaining_filters.push(term.clone());
+    }
+
+    /// Splits `predicate` into conjunctive terms, simplifies them and rewrites
+    /// each one with [`Self::push_term`].
+    ///
+    /// Returns `(left_filter, right_filter, remaining_terms)`, where each filter
+    /// is the `conjunction` of the terms rewritten for that input.
+    fn push_predicate(
+        mut self,
+        predicate: Expr,
+    ) -> Result<(Option<Expr>, Option<Expr>, Vec<Expr>)> {
+        let predicates = split_conjunction_owned(predicate);
+        let terms = simplify_predicates(predicates)?;
+        for term in terms {
+            self.push_term(&term);
+        }
+        Ok((
+            conjunction(self.left_filters),
+            conjunction(self.right_filters),
+            self.remaining_filters,
+        ))
+    }
+}
+
+/// Pushes parent filters on `coalesce` of the join keys into both inputs of
+/// `join`.
+///
+/// Returns the predicate terms that still have to be handled by the caller, or
+/// `None` (leaving `join` untouched) when nothing was pushed down.
+fn push_full_join_coalesce_filters(
+    join: &mut Join,
+    predicate: Expr,
+) -> Result<Option<Vec<Expr>>> {
+    let (Some(left), Some(right), remaining) =
+        PushDownCoalesceFilterHelper::new(&join.on).push_predicate(predicate)?
+    else {
+        return Ok(None);
+    };
+
+    let left_input = Arc::clone(&join.left);
+    join.left = Arc::new(make_filter(left, left_input)?);
+
+    let right_input = Arc::clone(&join.right);
+    join.right = Arc::new(make_filter(right, right_input)?);
+
+    Ok(Some(remaining))
+}
+
 /// push down join/cross-join
 fn push_down_all_join(
     predicates: Vec<Expr>,
@@ -527,13 +772,21 @@ fn push_down_all_join(
 }
 
 fn push_down_join(
-    join: Join,
+    mut join: Join,
     parent_predicate: Option<&Expr>,
 ) -> Result<Transformed<LogicalPlan>> {
     // Split the parent predicate into individual conjunctive parts.
-    let predicates = parent_predicate
+    let mut predicates = parent_predicate
         .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
 
+    if let Some(parent_predicate) = parent_predicate {
+        if let Some(remaining_predicates) =
+            push_full_join_coalesce_filters(&mut join, parent_predicate.clone())?
+        {
+            predicates = remaining_predicates;
+        }
+    }
+
     // Extract conjunctions from the JOIN's ON filter, if present.
     let on_filters = join
         .filter
@@ -1447,6 +1700,7 @@ mod tests {
     use crate::test::*;
     use crate::OptimizerContext;
     use datafusion_expr::test::function_stub::sum;
+    use datafusion_functions::core::expr_fn::coalesce;
     use insta::assert_snapshot;
 
     use super::*;
@@ -2848,6 +3102,36 @@ mod tests {
         )
     }
 
+    /// Filter on coalesce of join keys should be pushed to both join inputs
+    #[test]
+    fn filter_full_join_on_coalesce() -> Result<()> {
+        let table_scan_t1 = test_table_scan_with_name("t1")?;
+        let table_scan_t2 = test_table_scan_with_name("t2")?;
+
+        let plan = LogicalPlanBuilder::from(table_scan_t1)
+            .join(table_scan_t2, JoinType::Full, (vec!["a"], vec!["a"]), None)?
+            .filter(coalesce(vec![col("t1.a"), col("t2.a")]).eq(lit(1i32)))?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_snapshot!(plan,
+        @r"
+        Filter: coalesce(t1.a, t2.a) = Int32(1)
+          Full Join: t1.a = t2.a
+            TableScan: t1
+            TableScan: t2
+        ",
+        );
+        assert_optimized_plan_equal!(
+            plan,
+            @r"
+        Full Join: t1.a = t2.a
+          TableScan: t1, full_filters=[t1.a = Int32(1)]
+          TableScan: t2, full_filters=[t2.a = Int32(1)]
+        "
+        )
+    }
+
     /// join filter should be completely removed after pushdown
     #[test]
     fn join_filter_removed() -> Result<()> {