Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion/optimizer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ regex-syntax = "0.8.6"
async-trait = { workspace = true }
criterion = { workspace = true }
ctor = { workspace = true }
datafusion-functions = { workspace = true }
datafusion-functions-aggregate = { workspace = true }
datafusion-functions-window = { workspace = true }
datafusion-functions-window-common = { workspace = true }
Expand Down
243 changes: 240 additions & 3 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use datafusion_common::{
assert_eq_or_internal_err, assert_or_internal_err, internal_err, plan_err,
qualified_name, Column, DFSchema, DataFusionError, Result,
};
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::expr::{Between, InList, ScalarFunction, WindowFunction};
use datafusion_expr::expr_rewriter::replace_col;
use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
use datafusion_expr::utils::{
Expand Down Expand Up @@ -418,6 +418,204 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Ex
predicate
}

/// Tracks coalesce predicates that can be pushed to each side of a FULL JOIN.
struct PushDownCoalesceFilterHelper {
    // Equi-join key pairs as (left column, right column). Only pairs where
    // both sides are plain column references are kept (see `new`).
    join_keys: Vec<(Column, Column)>,
    // Rewritten predicates to be pushed below the join's left input.
    left_filters: Vec<Expr>,
    // Rewritten predicates to be pushed below the join's right input.
    right_filters: Vec<Expr>,
    // Conjuncts that could not be rewritten and must remain above the join.
    remaining_filters: Vec<Expr>,
}

impl PushDownCoalesceFilterHelper {
/// Builds a helper from the join's equi-key expression pairs, retaining
/// only the pairs where both sides are plain column references.
fn new(join_keys: &[(Expr, Expr)]) -> Self {
    let mut keys = Vec::with_capacity(join_keys.len());
    for (lhs, rhs) in join_keys {
        // Non-column keys (expressions) cannot participate in the rewrite.
        if let (Some(l), Some(r)) = (lhs.try_as_col(), rhs.try_as_col()) {
            keys.push((l.clone(), r.clone()));
        }
    }
    Self {
        join_keys: keys,
        left_filters: vec![],
        right_filters: vec![],
        remaining_filters: vec![],
    }
}

/// Applies `build_filter` to each of the paired join columns and records
/// the resulting predicate on the corresponding side of the join.
fn push_columns<F: FnMut(Expr) -> Expr>(
    &mut self,
    columns: (Column, Column),
    mut build_filter: F,
) {
    let (left_col, right_col) = columns;
    let left_pred = build_filter(Expr::Column(left_col));
    self.left_filters.push(left_pred);
    let right_pred = build_filter(Expr::Column(right_col));
    self.right_filters.push(right_pred);
}

/// If `expr` is a two-argument `coalesce` over one of the join's equi-key
/// column pairs (in either argument order), returns the pair normalized as
/// (left column, right column); otherwise returns `None`.
fn extract_join_columns(&self, expr: &Expr) -> Option<(Column, Column)> {
    // NOTE(review): this matches the function *name* "coalesce", so a
    // user-registered function that overrides the built-in would match too.
    // A dedicated function property marking pushdown-safety would be more
    // robust — confirm before relying on this with custom registries.
    let Expr::ScalarFunction(ScalarFunction { func, args }) = expr else {
        return None;
    };
    if func.name() != "coalesce" {
        return None;
    }
    // Only the two-argument form over plain column references is recognized.
    let [Expr::Column(a), Expr::Column(b)] = args.as_slice() else {
        return None;
    };
    self.join_keys.iter().find_map(|(join_lhs, join_rhs)| {
        if join_lhs == a && join_rhs == b {
            Some((a.clone(), b.clone()))
        } else if join_lhs == b && join_rhs == a {
            Some((b.clone(), a.clone()))
        } else {
            None
        }
    })
}

fn push_term(&mut self, term: &Expr) {
match term {
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op.supports_propagation() =>
{
if let Some(columns) = self.extract_join_columns(left) {
return self.push_columns(columns, |replacement| {
Expr::BinaryExpr(BinaryExpr {
left: Box::new(replacement),
op: *op,
right: right.clone(),
})
});
}
if let Some(columns) = self.extract_join_columns(right) {
return self.push_columns(columns, |replacement| {
Expr::BinaryExpr(BinaryExpr {
left: left.clone(),
op: *op,
right: Box::new(replacement),
})
});
}
}
Expr::IsNull(expr) => {
if let Some(columns) = self.extract_join_columns(expr) {
return self.push_columns(columns, |replacement| {
Expr::IsNull(Box::new(replacement))
});
}
}
Expr::IsNotNull(expr) => {
if let Some(columns) = self.extract_join_columns(expr) {
return self.push_columns(columns, |replacement| {
Expr::IsNotNull(Box::new(replacement))
});
}
}
Expr::IsTrue(expr) => {
if let Some(columns) = self.extract_join_columns(expr) {
return self.push_columns(columns, |replacement| {
Expr::IsTrue(Box::new(replacement))
});
}
}
Expr::IsFalse(expr) => {
if let Some(columns) = self.extract_join_columns(expr) {
return self.push_columns(columns, |replacement| {
Expr::IsFalse(Box::new(replacement))
});
}
}
Expr::IsUnknown(expr) => {
if let Some(columns) = self.extract_join_columns(expr) {
return self.push_columns(columns, |replacement| {
Expr::IsUnknown(Box::new(replacement))
});
}
}
Expr::IsNotTrue(expr) => {
if let Some(columns) = self.extract_join_columns(expr) {
return self.push_columns(columns, |replacement| {
Expr::IsNotTrue(Box::new(replacement))
});
}
}
Expr::IsNotFalse(expr) => {
if let Some(columns) = self.extract_join_columns(expr) {
return self.push_columns(columns, |replacement| {
Expr::IsNotFalse(Box::new(replacement))
});
}
}
Expr::IsNotUnknown(expr) => {
if let Some(columns) = self.extract_join_columns(expr) {
return self.push_columns(columns, |replacement| {
Expr::IsNotUnknown(Box::new(replacement))
});
}
}
Expr::Between(between) => {
if let Some(columns) = self.extract_join_columns(&between.expr) {
return self.push_columns(columns, |replacement| {
Expr::Between(Between {
expr: Box::new(replacement),
negated: between.negated,
low: between.low.clone(),
high: between.high.clone(),
})
});
}
}
Expr::InList(in_list) => {
if let Some(columns) = self.extract_join_columns(&in_list.expr) {
return self.push_columns(columns, |replacement| {
Expr::InList(InList {
expr: Box::new(replacement),
list: in_list.list.clone(),
negated: in_list.negated,
})
});
}
}
_ => {}
}
self.remaining_filters.push(term.clone());
}

/// Consumes the helper: splits `predicate` into simplified conjuncts,
/// routes each conjunct through [`Self::push_term`], and returns the
/// combined left-side filter, right-side filter, and the conjuncts that
/// must remain above the join.
fn push_predicate(
    mut self,
    predicate: Expr,
) -> Result<(Option<Expr>, Option<Expr>, Vec<Expr>)> {
    let terms = simplify_predicates(split_conjunction_owned(predicate))?;
    for term in &terms {
        self.push_term(term);
    }
    let left = conjunction(self.left_filters);
    let right = conjunction(self.right_filters);
    Ok((left, right, self.remaining_filters))
}
}

fn push_full_join_coalesce_filters(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to push filters into the inputs of a FULL JOIN , you need to guarantee that the join doens't reintroduce rows (with nulls) that would have been filtered if the filter was applied beforehand

In other words, it is not clear to me that this optimization is correct

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, forgot to change the name. I am using this optimization in my code specifically for chains (up to 50-table long) of FULL OUTER JOINs. I am making joins with a sequence join->project with coalesce over join keys -> alias, like:

let plan = LogicalPlanBuilder::scan("t1", scan.clone(), None)?
    .join(
        LogicalPlanBuilder::scan("t2", scan.clone(), None)?.build()?,
        JoinType::Full,
        (vec!["a"], vec!["a"]),
        None,
    )?
    .project(vec![
        coalesce(vec![col("t1.a"), col("t2.a")]).alias("a"),
        col("t1.b").alias("b1"),
        col("t2.b").alias("b2"),
    ])?
    .alias("j1")?
    .build()?;

This way the initial data which looks like

{
  "table1": {
    "1": 100,
    "2": 200,
    "3": 300
  },
  "table2": {
    "2": 2000,
    "3": 3000,
    "4": 4000
  },
  "table3": {
    "3": 30000,
    "4": 40000,
    "5": 50000
}

is joined into

key table1 table2 table3
1 100 null null
2 200 2000 null
3 300 3000 30000
4 null 4000 40000
5 null null 50000

instead of

key1 key2 key3 table1 table2 table3
1 null null 100 null null
2 2 null 200 2000 null
3 3 3 300 3000 30000
null 4 4 null 4000 40000
null null 5 null null 50000

You can check the illustration https://docs.platforma.bio/guides/vdj-analysis/diversity-analysis/#results-table where different sample properties are joined by Sample Id from different parquet files.

When I apply filter on Key what I effectively want is to replicate this filter to all input tables. And optimization that I provided does exactly that.

I am applying the chain join->project with coalesce over join keys -> alias for each new table, so for 50 tables I would have 49 projections with coalesce. Without my optimization, each optimizer pass has simplification which turns coalesce into CASE and then performs push-down which again turns case to coalesce. So 1 optimizer pass gives me propagation through 1 layer, and for 50 tables I would have to have 49 optimizer passes for full propagation. The optimization in this PR allows to optimize such scenario in 1 optimizer pass.

I realized that this optimization seems correct for any type of join if coalesce is applied to the join keys, so I do not have explicit check for FULL OUTER JOIN in proposed code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if you believe that nobody else has the scenario I described, this way we can simply close the PR and issue without further discussion)

join: &mut Join,
predicate: Expr,
) -> Result<Option<Vec<Expr>>> {
let (Some(left), Some(right), remaining) =
PushDownCoalesceFilterHelper::new(&join.on).push_predicate(predicate)?
else {
return Ok(None);
};

let left_input = Arc::clone(&join.left);
join.left = Arc::new(make_filter(left, left_input)?);

let right_input = Arc::clone(&join.right);
join.right = Arc::new(make_filter(right, right_input)?);

Ok(Some(remaining))
}

/// push down join/cross-join
fn push_down_all_join(
predicates: Vec<Expr>,
Expand Down Expand Up @@ -527,13 +725,21 @@ fn push_down_all_join(
}

fn push_down_join(
join: Join,
mut join: Join,
parent_predicate: Option<&Expr>,
) -> Result<Transformed<LogicalPlan>> {
// Split the parent predicate into individual conjunctive parts.
let predicates = parent_predicate
let mut predicates = parent_predicate
.map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));

if let Some(parent_predicate) = parent_predicate {
if let Some(remaining_predicates) =
push_full_join_coalesce_filters(&mut join, parent_predicate.clone())?
{
predicates = remaining_predicates;
}
}

// Extract conjunctions from the JOIN's ON filter, if present.
let on_filters = join
.filter
Expand Down Expand Up @@ -1447,6 +1653,7 @@ mod tests {
use crate::test::*;
use crate::OptimizerContext;
use datafusion_expr::test::function_stub::sum;
use datafusion_functions::core::expr_fn::coalesce;
use insta::assert_snapshot;

use super::*;
Expand Down Expand Up @@ -2848,6 +3055,36 @@ mod tests {
)
}

/// Filter on coalesce of join keys should be pushed to both join inputs
#[test]
fn filter_full_join_on_coalesce() -> Result<()> {
// Two identical test tables, FULL OUTER joined on column `a`.
let table_scan_t1 = test_table_scan_with_name("t1")?;
let table_scan_t2 = test_table_scan_with_name("t2")?;

// Filter on coalesce(t1.a, t2.a) — the merged join key — above the join.
let plan = LogicalPlanBuilder::from(table_scan_t1)
.join(table_scan_t2, JoinType::Full, (vec!["a"], vec!["a"]), None)?
.filter(coalesce(vec![col("t1.a"), col("t2.a")]).eq(lit(1i32)))?
.build()?;

// not part of the test, just good to know:
assert_snapshot!(plan,
@r"
Filter: coalesce(t1.a, t2.a) = Int32(1)
Full Join: t1.a = t2.a
TableScan: t1
TableScan: t2
",
);
// After optimization the coalesce filter is replicated to BOTH scan inputs
// as full_filters, and the filter node above the join disappears entirely.
assert_optimized_plan_equal!(
plan,
@r"
Full Join: t1.a = t2.a
TableScan: t1, full_filters=[t1.a = Int32(1)]
TableScan: t2, full_filters=[t2.a = Int32(1)]
"
)
}

/// join filter should be completely removed after pushdown
#[test]
fn join_filter_removed() -> Result<()> {
Expand Down