Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,12 @@ async fn multiple_or_predicates() -> Result<()> {
let expected =vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
" Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= CAST(Int64(20) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(30) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS Float64) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= CAST(Int64(20) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(30) AS Float64) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" CrossJoin: [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]",
" TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a great optimization. Next goal is to get it to see that part.p_size <= 15 as well :D

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oof, that would be excellent. For another PR if someone doesn't beat me to it :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe sooner... It looks like the order of the projections changes between runs, which is causing the test failure here.

" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,10 +495,10 @@ async fn use_between_expression_in_select_query() -> Result<()> {
.unwrap()
.to_string();

// Only test that the projection exprs arecorrect, rather than entire output
// Only test that the projection exprs are correct, rather than entire output
let needle = "ProjectionExec: expr=[c1@0 >= 2 AND c1@0 <= 3 as test.c1 BETWEEN Int64(2) AND Int64(3)]";
assert_contains!(&formatted, needle);
let needle = "Projection: #test.c1 BETWEEN Int64(2) AND Int64(3)";
let needle = "Projection: #test.c1 >= Int64(2) AND #test.c1 <= Int64(3)";
assert_contains!(&formatted, needle);

Ok(())
Expand Down
20 changes: 10 additions & 10 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use datafusion_expr::{
utils::from_plan,
Expr, ExprSchemable,
};
use std::collections::{HashMap, HashSet};
use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;

/// A map from expression's identifier to tuple including
Expand Down Expand Up @@ -271,12 +271,12 @@ fn to_arrays(
/// Build the "intermediate" projection plan that evaluates the extracted common expressions.
fn build_project_plan(
input: LogicalPlan,
affected_id: HashSet<Identifier>,
affected_id: BTreeSet<Identifier>,
expr_set: &ExprSet,
) -> Result<LogicalPlan> {
let mut project_exprs = vec![];
let mut fields = vec![];
let mut fields_set = HashSet::new();
let mut fields_set = BTreeSet::new();

for id in affected_id {
match expr_set.get(&id) {
Expand Down Expand Up @@ -320,7 +320,7 @@ fn rewrite_expr(
expr_set: &mut ExprSet,
optimizer_config: &OptimizerConfig,
) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
let mut affected_id = HashSet::<Identifier>::new();
let mut affected_id = BTreeSet::<Identifier>::new();

let rewrote_exprs = exprs_list
.iter()
Expand Down Expand Up @@ -482,7 +482,7 @@ struct CommonSubexprRewriter<'a> {
expr_set: &'a mut ExprSet,
id_array: &'a [(usize, Identifier)],
/// Which identifier is replaced.
affected_id: &'a mut HashSet<Identifier>,
affected_id: &'a mut BTreeSet<Identifier>,

/// the max series number we have rewritten. Other expression nodes
/// with smaller series number is already replaced and shouldn't
Expand Down Expand Up @@ -561,7 +561,7 @@ fn replace_common_expr(
expr: Expr,
id_array: &[(usize, Identifier)],
expr_set: &mut ExprSet,
affected_id: &mut HashSet<Identifier>,
affected_id: &mut BTreeSet<Identifier>,
) -> Result<Expr> {
expr.rewrite(&mut CommonSubexprRewriter {
expr_set,
Expand Down Expand Up @@ -752,7 +752,7 @@ mod test {
#[test]
fn redundant_project_fields() {
let table_scan = test_table_scan().unwrap();
let affected_id: HashSet<Identifier> =
let affected_id: BTreeSet<Identifier> =
["c+a".to_string(), "d+a".to_string()].into_iter().collect();
let expr_set = [
("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
Expand All @@ -764,7 +764,7 @@ mod test {
build_project_plan(table_scan, affected_id.clone(), &expr_set).unwrap();
let project_2 = build_project_plan(project, affected_id, &expr_set).unwrap();

let mut field_set = HashSet::new();
let mut field_set = BTreeSet::new();
for field in project_2.schema().fields() {
assert!(field_set.insert(field.qualified_name()));
}
Expand All @@ -779,7 +779,7 @@ mod test {
.unwrap()
.build()
.unwrap();
let affected_id: HashSet<Identifier> =
let affected_id: BTreeSet<Identifier> =
["c+a".to_string(), "d+a".to_string()].into_iter().collect();
let expr_set = [
("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
Expand All @@ -790,7 +790,7 @@ mod test {
let project = build_project_plan(join, affected_id.clone(), &expr_set).unwrap();
let project_2 = build_project_plan(project, affected_id, &expr_set).unwrap();

let mut field_set = HashSet::new();
let mut field_set = BTreeSet::new();
for field in project_2.schema().fields() {
assert!(field_set.insert(field.qualified_name()));
}
Expand Down
78 changes: 67 additions & 11 deletions datafusion/optimizer/src/simplify_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,6 @@ fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {

/// returns the contained boolean value in `expr` as
/// `Expr::Literal(ScalarValue::Boolean(v))`.
///
/// panics if expr is not a literal boolean
fn as_bool_lit(expr: Expr) -> Result<Option<bool>> {
match expr {
Expr::Literal(ScalarValue::Boolean(v)) => Ok(v),
Expand Down Expand Up @@ -502,7 +500,7 @@ impl<'a> ConstEvaluator<'a> {
ColumnarValue::Array(a) => {
if a.len() != 1 {
Err(DataFusionError::Execution(format!(
"Could not evaluate the expressison, found a result of length {}",
"Could not evaluate the expression, found a result of length {}",
a.len()
)))
} else {
Expand Down Expand Up @@ -803,6 +801,27 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
out_expr.rewrite(self)?
}

//
// Rules for Between
//

// a between 3 and 5 --> a >= 3 AND a <=5
// a not between 3 and 5 --> a < 3 OR a > 5
Between {
expr,
low,
high,
negated,
} => {
if negated {
let l = *expr.clone();
let r = *expr;
or(l.lt(*low), r.gt(*high))
} else {
and(expr.clone().gt_eq(*low), expr.lt_eq(*high))
}
}

expr => {
// no additional rewrites possible
expr
Expand Down Expand Up @@ -1555,8 +1574,13 @@ mod tests {
high: Box::new(lit(10)),
};
let expr = expr.or(lit_bool_null());
let result = simplify(expr.clone());
assert_eq!(expr, result);
let result = simplify(expr);

let expected_expr = or(
and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
lit_bool_null(),
);
assert_eq!(expected_expr, result);
}

#[test]
Expand All @@ -1579,17 +1603,49 @@ mod tests {
assert_eq!(simplify(lit_bool_null().and(lit(false))), lit(false),);

// c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
// it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10`
// and should not be rewritten
// it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)`
// and the Boolean(NULL) should remain
let expr = Expr::Between {
expr: Box::new(col("c1")),
negated: false,
low: Box::new(lit(0)),
high: Box::new(lit(10)),
};
let expr = expr.and(lit_bool_null());
let result = simplify(expr.clone());
assert_eq!(expr, result);
let result = simplify(expr);

let expected_expr = and(
and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
lit_bool_null(),
);
assert_eq!(expected_expr, result);
}

#[test]
fn simplify_expr_between() {
    // Builds `c2 [NOT] BETWEEN 3 AND 4`; only the negation flag varies
    // between the two cases exercised below.
    let between = |negated: bool| Expr::Between {
        expr: Box::new(col("c2")),
        negated,
        low: Box::new(lit(3)),
        high: Box::new(lit(4)),
    };

    // `c2 BETWEEN 3 AND 4` must simplify to `c2 >= 3 AND c2 <= 4`
    let inclusive_range = and(col("c2").gt_eq(lit(3)), col("c2").lt_eq(lit(4)));
    assert_eq!(simplify(between(false)), inclusive_range);

    // `c2 NOT BETWEEN 3 AND 4` must simplify to `c2 < 3 OR c2 > 4`
    let outside_range = or(col("c2").lt(lit(3)), col("c2").gt(lit(4)));
    assert_eq!(simplify(between(true)), outside_range);
}

// ------------------------------
Expand Down Expand Up @@ -2167,7 +2223,7 @@ mod tests {
.unwrap()
.build()
.unwrap();
let expected = "Filter: #test.d NOT BETWEEN Int32(1) AND Int32(10) AS NOT test.d BETWEEN Int32(1) AND Int32(10)\
let expected = "Filter: #test.d < Int32(1) OR #test.d > Int32(10) AS NOT test.d BETWEEN Int32(1) AND Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected);
Expand All @@ -2188,7 +2244,7 @@ mod tests {
.unwrap()
.build()
.unwrap();
let expected = "Filter: #test.d BETWEEN Int32(1) AND Int32(10) AS NOT test.d NOT BETWEEN Int32(1) AND Int32(10)\
let expected = "Filter: #test.d >= Int32(1) AND #test.d <= Int32(10) AS NOT test.d NOT BETWEEN Int32(1) AND Int32(10)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected);
Expand Down