diff --git a/benchmarks/expected-plans/q19.txt b/benchmarks/expected-plans/q19.txt index a8ce0e6c8527..552d743917dc 100644 --- a/benchmarks/expected-plans/q19.txt +++ b/benchmarks/expected-plans/q19.txt @@ -3,7 +3,7 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS re Projection: lineitem.l_extendedprice, lineitem.l_discount Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) Inner Join: lineitem.l_partkey = part.p_partkey - Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") + Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode] - Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15) AND part.p_size >= Int32(1) - TableScan: part projection=[p_partkey, p_brand, p_size, p_container] \ No newline at end of file + Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) + TableScan: part projection=[p_partkey, p_brand, p_size, p_container] diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 0b309412c0c6..95a1ef0d73a1 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -1482,7 +1482,7 @@ mod tests { let expr = col("c1") .lt(lit(1)) .and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3)))); - let expected_expr = "c1_min < Int32(1) AND c2_min <= Int32(2) AND Int32(2) <= c2_max OR c2_min <= Int32(3) AND Int32(3) <= c2_max"; + let expected_expr = "c1_min < Int32(1) AND (c2_min <= Int32(2) AND Int32(2) <= c2_max OR c2_min <= Int32(3) AND Int32(3) <= c2_max)"; let predicate_expr = build_predicate_expression(&expr, &schema, &mut required_columns)?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); @@ -1561,7 +1561,9 @@ mod tests { list: vec![lit(1), lit(2), lit(3)], negated: true, }; - let expected_expr = "c1_min != Int32(1) OR Int32(1) != c1_max AND c1_min != Int32(2) OR Int32(2) != c1_max AND c1_min != Int32(3) OR Int32(3) != c1_max"; + let expected_expr = "(c1_min != Int32(1) OR Int32(1) != c1_max) \ + AND (c1_min != Int32(2) OR Int32(2) != c1_max) \ + AND (c1_min != Int32(3) OR Int32(3) != c1_max)"; let predicate_expr = build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); @@ -1633,7 +1635,10 @@ mod tests { ], negated: true, }; - let expected_expr = "CAST(c1_min AS Int64) != Int64(1) OR Int64(1) != CAST(c1_max AS Int64) AND CAST(c1_min AS Int64) != Int64(2) OR Int64(2) != CAST(c1_max AS Int64) AND CAST(c1_min AS Int64) != Int64(3) OR Int64(3) != CAST(c1_max AS Int64)"; + let expected_expr = + "(CAST(c1_min AS Int64) != Int64(1) OR Int64(1) != CAST(c1_max AS Int64)) \ + AND (CAST(c1_min AS Int64) != Int64(2) OR Int64(2) != CAST(c1_max AS Int64)) \ + AND (CAST(c1_min AS Int64) != Int64(3) OR Int64(3) != CAST(c1_max AS Int64))"; let predicate_expr = build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?; assert_eq!(format!("{:?}", predicate_expr), expected_expr); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 1d11245c3881..781d8ea53c21 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -31,7 +31,7 @@ use datafusion_common::Result; use datafusion_common::{plan_err, Column}; use datafusion_common::{DataFusionError, ScalarValue}; use std::fmt; -use std::fmt::Write; +use std::fmt::{Display, Formatter, Write}; use std::hash::{BuildHasher, Hash, Hasher}; use std::ops::Not; use std::sync::Arc; @@ -265,6 +265,58 @@ impl BinaryExpr { pub fn new(left: Box, op: Operator, right: Box) -> Self { Self { left, op, right } } + + /// Get the operator precedence + /// use https://www.postgresql.org/docs/7.0/operators.htm#AEN2026 as a reference + pub fn precedence(&self) -> u8 { + match self.op { + Operator::Or => 5, + Operator::And => 10, + Operator::Like | Operator::NotLike => 19, + Operator::NotEq + | Operator::Eq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq => 20, + Operator::Plus | Operator::Minus => 30, + Operator::Multiply | Operator::Divide | Operator::Modulo => 40, + _ => 0, + } + } +} + +impl Display for BinaryExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // Put parentheses around child binary expressions so that we can see the difference + // between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed, + // based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are + // equivalent and the parentheses are not necessary. + + fn write_child( + f: &mut Formatter<'_>, + expr: &Expr, + precedence: u8, + ) -> fmt::Result { + match expr { + Expr::BinaryExpr(child) => { + let p = child.precedence(); + if p == 0 || p < precedence { + write!(f, "({})", child)?; + } else { + write!(f, "{}", child)?; + } + } + _ => write!(f, "{}", expr)?, + } + Ok(()) + } + + let precedence = self.precedence(); + write_child(f, self.left.as_ref(), precedence)?; + write!(f, " {} ", self.op)?; + write_child(f, self.right.as_ref(), precedence) + } } /// CASE expression @@ -717,9 +769,7 @@ impl fmt::Debug for Expr { negated: false, } => write!(f, "{:?} IN ({:?})", expr, subquery), Expr::ScalarSubquery(subquery) => write!(f, "({:?})", subquery), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - write!(f, "{:?} {} {:?}", left, op, right) - } + Expr::BinaryExpr(expr) => write!(f, "{}", expr), Expr::Sort { expr, asc, diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index c0edba96fe0f..daa164aec269 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -619,7 +619,7 @@ mod test { )?; let expected = vec![ - (9, "SUM(a + Int32(1)) - AVG(c) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), + (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), (4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"), (3, "a + Int32(1)Int32(1)a"), @@ -671,8 +671,8 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(test.a * Int32(1) - test.bInt32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b), SUM(test.a * Int32(1) - test.bInt32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b * Int32(1) + test.c)]]\ - \n Projection: test.a * Int32(1) - test.b AS test.a * Int32(1) - test.bInt32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\ + let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b), SUM(test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\ + \n Projection: test.a * (Int32(1) - test.b) AS test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\ \n TableScan: test"; assert_optimized_plan_eq(expected, &plan); diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs index 6396f1fbfd6c..148ae6715ddb 100644 --- a/datafusion/optimizer/src/filter_push_down.rs +++ b/datafusion/optimizer/src/filter_push_down.rs @@ -1044,7 +1044,7 @@ mod tests { let expected = "\ Projection: b * Int32(3) AS a, test.c\ \n Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n Filter: test.a * Int32(2) + test.c * Int32(3) = Int64(1)\ + \n Filter: (test.a * Int32(2) + test.c) * Int32(3) = Int64(1)\ \n TableScan: test"; assert_optimized_plan_eq(&plan, expected); Ok(()) diff --git a/datafusion/optimizer/src/reduce_cross_join.rs b/datafusion/optimizer/src/reduce_cross_join.rs index fa7c18afd3e8..45230ebb243b 100644 --- a/datafusion/optimizer/src/reduce_cross_join.rs +++ b/datafusion/optimizer/src/reduce_cross_join.rs @@ -848,7 +848,7 @@ mod tests { .build()?; let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) AND t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", @@ -936,7 +936,7 @@ mod tests { .build()?; let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) AND t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", diff --git a/datafusion/optimizer/src/subquery_filter_to_join.rs b/datafusion/optimizer/src/subquery_filter_to_join.rs index 8ec6f3892890..29f51a42f4e1 100644 --- a/datafusion/optimizer/src/subquery_filter_to_join.rs +++ b/datafusion/optimizer/src/subquery_filter_to_join.rs @@ -352,7 +352,7 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.a = UInt32(1) OR test.b IN () AND test.c IN () [a:UInt32, b:UInt32, c:UInt32]\ + \n Filter: (test.a = UInt32(1) OR test.b IN ()) AND test.c IN () [a:UInt32, b:UInt32, c:UInt32]\ \n Subquery: [c:UInt32]\ \n Projection: sq1.c [c:UInt32]\ \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 92264f06023f..214dbd6adf77 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -3413,7 +3413,7 @@ mod tests { #[test] fn select_binary_expr_nested() { let sql = "SELECT (age + salary)/2 from person"; - let expected = "Projection: person.age + person.salary / Int64(2)\ + let expected = "Projection: (person.age + person.salary) / Int64(2)\ \n TableScan: person"; quick_test(sql, expected); } @@ -3848,7 +3848,7 @@ mod tests { fn select_where_nullif_division() { let sql = "SELECT c3/(c4+c5) \ FROM aggregate_test_100 WHERE c3/nullif(c4+c5, 0) > 0.1"; - let expected = "Projection: aggregate_test_100.c3 / aggregate_test_100.c4 + aggregate_test_100.c5\ + let expected = "Projection: aggregate_test_100.c3 / (aggregate_test_100.c4 + aggregate_test_100.c5)\ \n Filter: aggregate_test_100.c3 / nullif(aggregate_test_100.c4 + aggregate_test_100.c5, Int64(0)) > Float64(0.1)\ \n TableScan: aggregate_test_100"; quick_test(sql, expected);