diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 0782302d736e..afa5592636ab 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -419,6 +419,18 @@ impl LogicalPlanBuilder { Ok(Self::from(union_with_alias(self.plan.clone(), plan, None)?)) } + pub fn union_with_alias( + &self, + plan: LogicalPlan, + alias: Option, + ) -> Result { + Ok(Self::from(union_with_alias( + self.plan.clone(), + plan, + alias, + )?)) + } + /// Apply a union, removing duplicate rows pub fn union_distinct(&self, plan: LogicalPlan) -> Result { // unwrap top-level Distincts, to avoid duplication diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6ad643fd1ae3..0880916b5fc7 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1145,6 +1145,26 @@ impl Projection { }) } + /// Create a new Projection using the specified output schema + pub fn new_from_schema( + input: Arc, + schema: DFSchemaRef, + alias: Option, + ) -> Self { + let expr: Vec = schema + .fields() + .iter() + .map(|field| field.qualified_column()) + .map(Expr::Column) + .collect(); + Self { + expr, + input, + schema, + alias, + } + } + pub fn try_from_plan(plan: &LogicalPlan) -> datafusion_common::Result<&Projection> { match plan { LogicalPlan::Projection(it) => Ok(it), diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 13d4cf4a328a..e62cbbd73103 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -26,6 +26,7 @@ pub mod inline_table_scan; pub mod limit_push_down; pub mod optimizer; pub mod projection_push_down; +pub mod propagate_empty_relation; pub mod reduce_cross_join; pub mod reduce_outer_join; pub mod scalar_subquery_to_join; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 315f47499f78..c508d8508e01 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -27,6 +27,7 @@ use crate::filter_push_down::FilterPushDown; use crate::inline_table_scan::InlineTableScan; use crate::limit_push_down::LimitPushDown; use crate::projection_push_down::ProjectionPushDown; +use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::reduce_cross_join::ReduceCrossJoin; use crate::reduce_outer_join::ReduceOuterJoin; use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; @@ -165,6 +166,7 @@ impl Optimizer { Arc::new(ReduceCrossJoin::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), + Arc::new(PropagateEmptyRelation::new()), Arc::new(RewriteDisjunctivePredicate::new()), ]; if config.filter_null_keys { diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs new file mode 100644 index 000000000000..59f88cef6716 --- /dev/null +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -0,0 +1,412 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::{EmptyRelation, JoinType, Projection, Union}; +use std::sync::Arc; + +use crate::{utils, OptimizerConfig, OptimizerRule}; + +/// Optimization rule that bottom-up to eliminate plan by propagating empty_relation. +#[derive(Default)] +pub struct PropagateEmptyRelation; + +impl PropagateEmptyRelation { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for PropagateEmptyRelation { + fn optimize( + &self, + plan: &LogicalPlan, + optimizer_config: &mut OptimizerConfig, + ) -> Result { + // optimize child plans first + let optimized_children_plan = + utils::optimize_children(self, plan, optimizer_config)?; + match &optimized_children_plan { + LogicalPlan::EmptyRelation(_) => Ok(optimized_children_plan), + LogicalPlan::Projection(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Window(_) + | LogicalPlan::Sort(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Limit(_) => match empty_child(&optimized_children_plan)? { + Some(empty) => Ok(empty), + None => Ok(optimized_children_plan), + }, + LogicalPlan::CrossJoin(_) => { + let (left_empty, right_empty) = + binary_plan_children_is_empty(&optimized_children_plan)?; + if left_empty || right_empty { + Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: optimized_children_plan.schema().clone(), + })) + } else { + Ok(optimized_children_plan) + } + } + LogicalPlan::Join(join) => { + // TODO: For Join, more join type need to be careful: + // For LeftOuter/LeftSemi/LeftAnti Join, only the left side is empty, the Join result is empty. + // For LeftSemi Join, if the right side is empty, the Join result is empty. + // For LeftAnti Join, if the right side is empty, the Join result is left side(should exclude null ??). + // For RightOuter/RightSemi/RightAnti Join, only the right side is empty, the Join result is empty. + // For RightSemi Join, if the left side is empty, the Join result is empty. + // For RightAnti Join, if the left side is empty, the Join result is right side(should exclude null ??). + // For Full Join, only both sides are empty, the Join result is empty. + // For LeftOut/Full Join, if the right side is empty, the Join can be eliminated with a Projection with left side + // columns + right side columns replaced with null values. + // For RightOut/Full Join, if the left side is empty, the Join can be eliminated with a Projection with right side + // columns + left side columns replaced with null values. + if join.join_type == JoinType::Inner { + let (left_empty, right_empty) = + binary_plan_children_is_empty(&optimized_children_plan)?; + if left_empty || right_empty { + Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: optimized_children_plan.schema().clone(), + })) + } else { + Ok(optimized_children_plan) + } + } else { + Ok(optimized_children_plan) + } + } + LogicalPlan::Union(union) => { + let new_inputs = union + .inputs + .iter() + .filter(|input| match &***input { + LogicalPlan::EmptyRelation(empty) => empty.produce_one_row, + _ => true, + }) + .cloned() + .collect::>(); + + if new_inputs.len() == union.inputs.len() { + Ok(optimized_children_plan) + } else if new_inputs.is_empty() { + Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: optimized_children_plan.schema().clone(), + })) + } else if new_inputs.len() == 1 { + let child = (**(union.inputs.get(0).unwrap())).clone(); + if child.schema().eq(optimized_children_plan.schema()) { + Ok(child) + } else { + Ok(LogicalPlan::Projection(Projection::new_from_schema( + Arc::new(child), + optimized_children_plan.schema().clone(), + union.alias.clone(), + ))) + } + } else { + Ok(LogicalPlan::Union(Union { + inputs: new_inputs, + schema: union.schema.clone(), + alias: union.alias.clone(), + })) + } + } + LogicalPlan::Aggregate(agg) => { + if !agg.group_expr.is_empty() { + match empty_child(&optimized_children_plan)? { + Some(empty) => Ok(empty), + None => Ok(optimized_children_plan), + } + } else { + Ok(optimized_children_plan) + } + } + _ => Ok(optimized_children_plan), + } + } + + fn name(&self) -> &str { + "propagate_empty_relation" + } +} + +fn binary_plan_children_is_empty(plan: &LogicalPlan) -> Result<(bool, bool)> { + let inputs = plan.inputs(); + + // all binary-plan need to deal with separately. + match inputs.len() { + 2 => { + let left = inputs.get(0).unwrap(); + let right = inputs.get(1).unwrap(); + + let left_empty = match left { + LogicalPlan::EmptyRelation(empty) => !empty.produce_one_row, + _ => false, + }; + let right_empty = match right { + LogicalPlan::EmptyRelation(empty) => !empty.produce_one_row, + _ => false, + }; + Ok((left_empty, right_empty)) + } + _ => Err(DataFusionError::Plan( + "plan just can have two child".to_string(), + )), + } +} + +fn empty_child(plan: &LogicalPlan) -> Result> { + let inputs = plan.inputs(); + + // all binary-plan need to deal with separately. + match inputs.len() { + 1 => { + let input = inputs.get(0).unwrap(); + match input { + LogicalPlan::EmptyRelation(empty) => { + if !empty.produce_one_row { + Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: plan.schema().clone(), + }))) + } else { + Ok(None) + } + } + _ => Ok(None), + } + } + _ => Err(DataFusionError::Plan( + "plan just can have one child".to_string(), + )), + } +} + +#[cfg(test)] +mod tests { + use crate::eliminate_filter::EliminateFilter; + use crate::test::{test_table_scan, test_table_scan_with_name}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{Column, ScalarValue}; + use datafusion_expr::logical_plan::table_scan; + use datafusion_expr::{ + binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, JoinType, + Operator, + }; + + use super::*; + + fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { + let rule = PropagateEmptyRelation::new(); + let optimized_plan = rule + .optimize(plan, &mut OptimizerConfig::new()) + .expect("failed to optimize plan"); + let formatted_plan = format!("{:?}", optimized_plan); + assert_eq!(formatted_plan, expected); + assert_eq!(plan.schema(), optimized_plan.schema()); + } + + fn assert_together_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { + let optimize_one = EliminateFilter::new() + .optimize(plan, &mut OptimizerConfig::new()) + .expect("failed to optimize plan"); + let optimize_two = PropagateEmptyRelation::new() + .optimize(&optimize_one, &mut OptimizerConfig::new()) + .expect("failed to optimize plan"); + let formatted_plan = format!("{:?}", optimize_two); + assert_eq!(formatted_plan, expected); + assert_eq!(plan.schema(), optimize_two.schema()); + } + + #[test] + fn propagate_empty() -> Result<()> { + let plan = LogicalPlanBuilder::empty(false) + .filter(Expr::Literal(ScalarValue::Boolean(Some(true))))? + .limit(10, None)? + .project(vec![binary_expr(lit(1), Operator::Plus, lit(1))])? + .build()?; + + let expected = "EmptyRelation"; + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn cooperate_with_eliminate_filter() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a")])? + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build()?; + + let plan = LogicalPlanBuilder::from(left) + .join_using( + &right, + JoinType::Inner, + vec![Column::from_name("a".to_string())], + )? + .filter(col("a").lt_eq(lit(1i64)))? + .build()?; + + let expected = "EmptyRelation"; + assert_together_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn propagate_union_empty() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan()?).build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build()?; + + let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; + + let expected = "Projection: a, b, c\ + \n TableScan: test"; + assert_together_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn propagate_union_multi_empty() -> Result<()> { + let one = + LogicalPlanBuilder::from(test_table_scan_with_name("test1")?).build()?; + let two = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build()?; + let three = LogicalPlanBuilder::from(test_table_scan_with_name("test3")?) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build()?; + let four = + LogicalPlanBuilder::from(test_table_scan_with_name("test4")?).build()?; + + let plan = LogicalPlanBuilder::from(one) + .union(two)? + .union(three)? + .union(four)? + .build()?; + + let expected = "Union\ + \n TableScan: test1\ + \n TableScan: test4"; + assert_together_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn propagate_union_all_empty() -> Result<()> { + let one = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build()?; + let two = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build()?; + let three = LogicalPlanBuilder::from(test_table_scan_with_name("test3")?) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build()?; + let four = LogicalPlanBuilder::from(test_table_scan_with_name("test4")?) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build()?; + + let plan = LogicalPlanBuilder::from(one) + .union(two)? + .union(three)? + .union(four)? + .build()?; + + let expected = "EmptyRelation"; + assert_together_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn propagate_union_children_different_schema() -> Result<()> { + let one_schema = Schema::new(vec![Field::new("t1a", DataType::UInt32, false)]); + let t1_scan = table_scan(Some("test1"), &one_schema, None)?.build()?; + let one = LogicalPlanBuilder::from(t1_scan) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build()?; + + let two_schema = Schema::new(vec![Field::new("t2a", DataType::UInt32, false)]); + let t2_scan = table_scan(Some("test2"), &two_schema, None)?.build()?; + let two = LogicalPlanBuilder::from(t2_scan).build()?; + + let three_schema = Schema::new(vec![Field::new("t3a", DataType::UInt32, false)]); + let t3_scan = table_scan(Some("test3"), &three_schema, None)?.build()?; + let three = LogicalPlanBuilder::from(t3_scan).build()?; + + let plan = LogicalPlanBuilder::from(one) + .union(two)? + .union(three)? + .build()?; + + let expected = "Union\ + \n TableScan: test2\ + \n TableScan: test3"; + assert_together_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn propagate_union_alias() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan()?).build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build()?; + + let plan = LogicalPlanBuilder::from(left) + .union_with_alias(right, Some("union".to_string()))? + .build()?; + + let expected = "Projection: union.a, union.b, union.c, alias=union\ + \n TableScan: test"; + assert_together_optimized_plan_eq(&plan, expected); + Ok(()) + } + + #[test] + fn cross_join_empty() -> Result<()> { + let table_scan = test_table_scan()?; + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right = LogicalPlanBuilder::empty(false).build()?; + + let plan = LogicalPlanBuilder::from(left) + .cross_join(&right)? + .filter(col("a").lt_eq(lit(1i64)))? + .build()?; + + let expected = "EmptyRelation"; + assert_together_optimized_plan_eq(&plan, expected); + + Ok(()) + } +} diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 779f156c035f..7d092b00bba6 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -258,6 +258,15 @@ fn timestamp_nano_ts_utc_predicates() { assert_eq!(expected, format!("{:?}", plan)); } +#[test] +fn propagate_empty_relation() { + let sql = "SELECT col_int32 FROM test JOIN ( SELECT col_int32 FROM test WHERE false ) AS ta1 ON test.col_int32 = ta1.col_int32;"; + let plan = test_sql(sql).unwrap(); + // when children exist EmptyRelation, it will bottom-up propagate. + let expected = "EmptyRelation"; + assert_eq!(expected, format!("{:?}", plan)); +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...