Skip to content

Commit

Permalink
simplify the between expr during logical plan optimization (#3404)
Browse files Browse the repository at this point in the history
* rewrite between expression so that it can be further optimized and pushed down

* update tests

* update for comment and test

* fix common_subexpr_eliminate to retain predictable ordering between runs
  • Loading branch information
kmitchener committed Sep 9, 2022
1 parent eaf1d46 commit 73447b5
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 26 deletions.
7 changes: 4 additions & 3 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,12 @@ async fn multiple_or_predicates() -> Result<()> {
let expected =vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
" Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= CAST(Int64(20) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(30) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS Float64) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= CAST(Int64(20) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(30) AS Float64) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" CrossJoin: [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]",
" TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,10 +495,10 @@ async fn use_between_expression_in_select_query() -> Result<()> {
.unwrap()
.to_string();

// Only test that the projection exprs arecorrect, rather than entire output
// Only test that the projection exprs are correct, rather than entire output
let needle = "ProjectionExec: expr=[c1@0 >= 2 AND c1@0 <= 3 as test.c1 BETWEEN Int64(2) AND Int64(3)]";
assert_contains!(&formatted, needle);
let needle = "Projection: #test.c1 BETWEEN Int64(2) AND Int64(3)";
let needle = "Projection: #test.c1 >= Int64(2) AND #test.c1 <= Int64(3)";
assert_contains!(&formatted, needle);

Ok(())
Expand Down
20 changes: 10 additions & 10 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use datafusion_expr::{
utils::from_plan,
Expr, ExprSchemable,
};
use std::collections::{HashMap, HashSet};
use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;

/// A map from expression's identifier to tuple including
Expand Down Expand Up @@ -271,12 +271,12 @@ fn to_arrays(
/// Build the "intermediate" projection plan that evaluates the extracted common expressions.
fn build_project_plan(
input: LogicalPlan,
affected_id: HashSet<Identifier>,
affected_id: BTreeSet<Identifier>,
expr_set: &ExprSet,
) -> Result<LogicalPlan> {
let mut project_exprs = vec![];
let mut fields = vec![];
let mut fields_set = HashSet::new();
let mut fields_set = BTreeSet::new();

for id in affected_id {
match expr_set.get(&id) {
Expand Down Expand Up @@ -320,7 +320,7 @@ fn rewrite_expr(
expr_set: &mut ExprSet,
optimizer_config: &OptimizerConfig,
) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
let mut affected_id = HashSet::<Identifier>::new();
let mut affected_id = BTreeSet::<Identifier>::new();

let rewrote_exprs = exprs_list
.iter()
Expand Down Expand Up @@ -482,7 +482,7 @@ struct CommonSubexprRewriter<'a> {
expr_set: &'a mut ExprSet,
id_array: &'a [(usize, Identifier)],
/// Which identifier is replaced.
affected_id: &'a mut HashSet<Identifier>,
affected_id: &'a mut BTreeSet<Identifier>,

/// the max series number we have rewritten. Other expression nodes
/// with smaller series number is already replaced and shouldn't
Expand Down Expand Up @@ -561,7 +561,7 @@ fn replace_common_expr(
expr: Expr,
id_array: &[(usize, Identifier)],
expr_set: &mut ExprSet,
affected_id: &mut HashSet<Identifier>,
affected_id: &mut BTreeSet<Identifier>,
) -> Result<Expr> {
expr.rewrite(&mut CommonSubexprRewriter {
expr_set,
Expand Down Expand Up @@ -752,7 +752,7 @@ mod test {
#[test]
fn redundant_project_fields() {
let table_scan = test_table_scan().unwrap();
let affected_id: HashSet<Identifier> =
let affected_id: BTreeSet<Identifier> =
["c+a".to_string(), "d+a".to_string()].into_iter().collect();
let expr_set = [
("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
Expand All @@ -764,7 +764,7 @@ mod test {
build_project_plan(table_scan, affected_id.clone(), &expr_set).unwrap();
let project_2 = build_project_plan(project, affected_id, &expr_set).unwrap();

let mut field_set = HashSet::new();
let mut field_set = BTreeSet::new();
for field in project_2.schema().fields() {
assert!(field_set.insert(field.qualified_name()));
}
Expand All @@ -779,7 +779,7 @@ mod test {
.unwrap()
.build()
.unwrap();
let affected_id: HashSet<Identifier> =
let affected_id: BTreeSet<Identifier> =
["c+a".to_string(), "d+a".to_string()].into_iter().collect();
let expr_set = [
("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
Expand All @@ -790,7 +790,7 @@ mod test {
let project = build_project_plan(join, affected_id.clone(), &expr_set).unwrap();
let project_2 = build_project_plan(project, affected_id, &expr_set).unwrap();

let mut field_set = HashSet::new();
let mut field_set = BTreeSet::new();
for field in project_2.schema().fields() {
assert!(field_set.insert(field.qualified_name()));
}
Expand Down
78 changes: 67 additions & 11 deletions datafusion/optimizer/src/simplify_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,6 @@ fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {

/// returns the contained boolean value in `expr` as
/// `Expr::Literal(ScalarValue::Boolean(v))`.
///
/// panics if expr is not a literal boolean
fn as_bool_lit(expr: Expr) -> Result<Option<bool>> {
match expr {
Expr::Literal(ScalarValue::Boolean(v)) => Ok(v),
Expand Down Expand Up @@ -502,7 +500,7 @@ impl<'a> ConstEvaluator<'a> {
ColumnarValue::Array(a) => {
if a.len() != 1 {
Err(DataFusionError::Execution(format!(
"Could not evaluate the expressison, found a result of length {}",
"Could not evaluate the expression, found a result of length {}",
a.len()
)))
} else {
Expand Down Expand Up @@ -803,6 +801,27 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
out_expr.rewrite(self)?
}

//
// Rules for Between
//

// a between 3 and 5 --> a >= 3 AND a <=5
// a not between 3 and 5 --> a < 3 OR a > 5
Between {
expr,
low,
high,
negated,
} => {
if negated {
let l = *expr.clone();
let r = *expr;
or(l.lt(*low), r.gt(*high))
} else {
and(expr.clone().gt_eq(*low), expr.lt_eq(*high))
}
}

expr => {
// no additional rewrites possible
expr
Expand Down Expand Up @@ -1555,8 +1574,13 @@ mod tests {
high: Box::new(lit(10)),
};
let expr = expr.or(lit_bool_null());
let result = simplify(expr.clone());
assert_eq!(expr, result);
let result = simplify(expr);

let expected_expr = or(
and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
lit_bool_null(),
);
assert_eq!(expected_expr, result);
}

#[test]
Expand All @@ -1579,17 +1603,49 @@ mod tests {
assert_eq!(simplify(lit_bool_null().and(lit(false))), lit(false),);

// c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
// it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10`
// and should not be rewritten
// it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
// and the Boolean(NULL) should remain
let expr = Expr::Between {
expr: Box::new(col("c1")),
negated: false,
low: Box::new(lit(0)),
high: Box::new(lit(10)),
};
let expr = expr.and(lit_bool_null());
let result = simplify(expr.clone());
assert_eq!(expr, result);
let result = simplify(expr);

let expected_expr = and(
and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
lit_bool_null(),
);
assert_eq!(expected_expr, result);
}

#[test]
fn simplify_expr_between() {
// c2 between 3 and 4 is c2 >= 3 and c2 <= 4
let expr = Expr::Between {
expr: Box::new(col("c2")),
negated: false,
low: Box::new(lit(3)),
high: Box::new(lit(4)),
};
assert_eq!(
simplify(expr),
and(col("c2").gt_eq(lit(3)), col("c2").lt_eq(lit(4)))
);

// c2 not between 3 and 4 is c2 < 3 or c2 > 4
let expr = Expr::Between {
expr: Box::new(col("c2")),
negated: true,
low: Box::new(lit(3)),
high: Box::new(lit(4)),
};
assert_eq!(
simplify(expr),
or(col("c2").lt(lit(3)), col("c2").gt(lit(4)))
);
}

// ------------------------------
Expand Down Expand Up @@ -2167,7 +2223,7 @@ mod tests {
.unwrap()
.build()
.unwrap();
let expected = "Filter: #test.d NOT BETWEEN Int32(1) AND Int32(10) AS NOT test.d BETWEEN Int32(1) AND Int32(10)\
let expected = "Filter: #test.d < Int32(1) OR #test.d > Int32(10) AS NOT test.d BETWEEN Int32(1) AND Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -2188,7 +2244,7 @@ mod tests {
.unwrap()
.build()
.unwrap();
let expected = "Filter: #test.d BETWEEN Int32(1) AND Int32(10) AS NOT test.d NOT BETWEEN Int32(1) AND Int32(10)\
let expected = "Filter: #test.d >= Int32(1) AND #test.d <= Int32(10) AS NOT test.d NOT BETWEEN Int32(1) AND Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected);
Expand Down

0 comments on commit 73447b5

Please sign in to comment.