From d81894f0bb570e6646cbd8cba7de58d834001b79 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 10 Mar 2017 11:03:33 +0000
Subject: [PATCH 1/2] Support more expression canonicalization.

---
 .../catalyst/expressions/Canonicalize.scala   | 118 ++++++++++++++++--
 .../expressions/ExpressionSetSuite.scala      |  47 +++++++
 2 files changed, 158 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
index 65e497afc12cd..02f9ff5781f49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.types.DataType
+
 /**
  * Rewrites an expression using rules that are guaranteed preserve the result while attempting
  * to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization
@@ -43,11 +45,11 @@
     case _ => e
   }
 
-  /** Collects adjacent commutative operations. */
-  private def gatherCommutative(
+  /** Collects adjacent operations. The operations can be non-commutative. */
+  private def gatherAdjacent(
       e: Expression,
       f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match {
-    case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f))
+    case c if f.isDefinedAt(c) => f(c).flatMap(gatherAdjacent(_, f))
     case other => other :: Nil
   }
 
@@ -55,12 +57,88 @@
   private def orderCommutative(
       e: Expression,
       f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] =
-    gatherCommutative(e, f).sortBy(_.hashCode())
+    gatherAdjacent(e, f).sortBy(_.hashCode())
 
-  /** Rearrange expressions that are commutative or associative. */
+  /** Rearrange expressions that are commutative, associative, or semantically equal. */
   private def expressionReorder(e: Expression): Expression = e match {
-    case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
-    case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)
+    case UnaryMinus(UnaryMinus(c)) => c
+
+    // If the expression is composed of `Add` and `Subtract`, we rearrange it by extracting all
+    // sub-expressions like:
+    // a + b => a, b
+    // a - b => a, -b
+    // -(a + b) => -a, -b
+    // -(a - b) => -a, b
+    // Then we combine those sub-expressions as follows:
+    // 1. Remove the pairs of sub-expressions like (b, -b).
+    // 2. Concatenate the remaining sub-expressions with `Add`.
+    case a: Add =>
+      // Extract sub-expressions.
+      val (positiveExprs, negativeExprs) = gatherAdjacent(a, {
+        case Add(l, r) => Seq(l, r)
+        case Subtract(l, r) => Seq(l, UnaryMinus(r))
+        case UnaryMinus(Add(l, r)) => Seq(UnaryMinus(l), UnaryMinus(r))
+        case UnaryMinus(Subtract(l, r)) => Seq(UnaryMinus(l), r)
+      }).map { e =>
+        e.transform { case UnaryMinus(UnaryMinus(c)) => c }
+      }.filter {
+        case Literal(0, _) => false
+        case UnaryMinus(Literal(0, _)) => false
+        case _ => true
+      }.partition(!_.isInstanceOf[UnaryMinus])
+
+      // Remove the pairs of sub-expressions like (b, -b).
+      val (newLeftExprs, newRightExprs) = filterOutExprs(positiveExprs,
+        negativeExprs.map(_.asInstanceOf[UnaryMinus].child))
+
+      val finalExprs = (newLeftExprs ++ newRightExprs.map(UnaryMinus(_))).sortBy(_.hashCode())
+      if (finalExprs.isEmpty) {
+        Literal(0, a.dataType)
+      } else {
+        finalExprs.reduce(Add)
+      }
+
+    case Subtract(sl, sr) => expressionReorder(Add(sl, UnaryMinus(sr)))
+
+    // If the expression is composed of `Multiply` and `Divide`, we rearrange it by extracting all
+    // sub-expressions like:
+    // a * b => a, b
+    // a / b => a, 1 / b
+    // 1 / (a * b) => 1 / a, 1 / b
+    // 1 / (a / b) => 1 / a, b
+    // Then we combine those sub-expressions as follows:
+    // 1. Remove the pairs of sub-expressions like (b, 1 / b).
+    // 2. Concatenate the remaining sub-expressions with `Multiply` and `Divide`.
+    case m: Multiply =>
+      // Extract sub-expressions.
+      val (multiplyExprs, reciprocalExprs) = gatherAdjacent(m, {
+        case Multiply(l, r) => Seq(l, r)
+        case Divide(l, r) => Seq(l, UnaryReciprocal(r))
+        case UnaryReciprocal(Multiply(l, r)) => Seq(UnaryReciprocal(l), UnaryReciprocal(r))
+        case UnaryReciprocal(Divide(l, r)) => Seq(UnaryReciprocal(l), r)
+      }).map { e =>
+        e.transform { case UnaryReciprocal(UnaryReciprocal(c)) => c }
+      }.filter {
+        case Literal(1, _) => false
+        case UnaryReciprocal(Literal(1, _)) => false
+        case _ => true
+      }.partition(!_.isInstanceOf[UnaryReciprocal])
+
+      // Remove the pairs of sub-expressions like (b, 1 / b).
+      val (newLeftExprs, newRightExprs) = filterOutExprs(multiplyExprs,
+        reciprocalExprs.map(_.asInstanceOf[UnaryReciprocal].child))
+
+      val finalExprs = (newLeftExprs ++ newRightExprs.map(UnaryReciprocal(_))).sortBy(_.hashCode())
+      if (finalExprs.isEmpty) {
+        Literal(1, m.dataType)
+      } else {
+        finalExprs.map {
+          case u: UnaryReciprocal => Divide(Literal(1, u.dataType), u.child)
+          case other => other
+        }.reduce(Multiply)
+      }
+
+    case Divide(dl, dr) => expressionReorder(Multiply(dl, UnaryReciprocal(dr)))
 
     case o: Or =>
       orderCommutative(o, { case Or(l, r) if l.deterministic && r.deterministic => Seq(l, r) })
@@ -87,4 +165,30 @@
 
     case _ => e
   }
+
+  /** Finds the expressions existing in both sets and drops them from both. */
+  private def filterOutExprs(
+      leftExprs: Seq[Expression],
+      rightExprs: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+    var newLeftExprs = leftExprs
+    val foundIndexes = rightExprs.zipWithIndex.map { case (r, rIndex) =>
+      val found = newLeftExprs.indexWhere(_.semanticEquals(r))
+      if (found >= 0) {
+        newLeftExprs = newLeftExprs.slice(0, found) ++
+          newLeftExprs.slice(found + 1, newLeftExprs.length)
+      }
+      (found, rIndex)
+    }
+    val dropRightIndexes = foundIndexes.filter(_._1 >= 0).unzip._2
+    val newRightExprs = rightExprs.zipWithIndex.filterNot { case (r, index) =>
+      dropRightIndexes.contains(index)
+    }.unzip._1
+    (newLeftExprs, newRightExprs)
+  }
+}
+
+/** A private [[UnaryExpression]] only used in expression canonicalization. */
+private[expressions] case class UnaryReciprocal(child: Expression)
+  extends UnaryExpression with Unevaluable {
+  override def dataType: DataType = child.dataType
 }
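The scheme in the Add case above is easier to see outside Catalyst. Below is a minimal,
self-contained Scala sketch of the same idea. Sym, Neg, Plus and Minus are illustrative
stand-ins for Catalyst's expression classes, not part of this patch:

    sealed trait Term
    case class Sym(name: String) extends Term
    case class Neg(child: Term) extends Term
    case class Plus(l: Term, r: Term) extends Term
    case class Minus(l: Term, r: Term) extends Term

    def negate(t: Term): Term = t match {
      case Neg(c) => c // -(-t) == t
      case other => Neg(other)
    }

    // Mirrors gatherAdjacent with the four extraction rules in the patch.
    def gather(t: Term): Seq[Term] = t match {
      case Plus(l, r) => gather(l) ++ gather(r)
      case Minus(l, r) => gather(l) ++ gather(r).map(negate)
      case Neg(Plus(l, r)) => (gather(l) ++ gather(r)).map(negate)
      case Neg(Minus(l, r)) => gather(l).map(negate) ++ gather(r)
      case other => Seq(other)
    }

    // Mirrors filterOutExprs: drop matching (t, -t) pairs.
    def cancel(terms: Seq[Term]): Seq[Term] = {
      val (pos, neg) = terms.partition(!_.isInstanceOf[Neg])
      var remaining = pos
      val keptNeg = neg.filter {
        case Neg(c) if remaining.contains(c) =>
          remaining = remaining.patch(remaining.indexOf(c), Nil, 1)
          false
        case _ => true
      }
      remaining ++ keptNeg
    }

    // (a + b) - b flattens to {a, b, -b}; the (b, -b) pair cancels, leaving a.
    assert(cancel(gather(Minus(Plus(Sym("a"), Sym("b")), Sym("b")))) == Seq(Sym("a")))

The Multiply/Divide case follows the same pattern, with UnaryReciprocal playing the
role of the sign tag.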
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
index d617ad540d5ff..b9359ba95b9ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
@@ -198,6 +198,53 @@
     Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (aUpper > 10 && bUpper > 10),
     Rand(1L) > aUpper || (aUpper > bUpper && aUpper <= Rand(1L)) || (aUpper > 10 && bUpper > 10))
 
+  // Canonicalize the expressions composed of `Add` and `Subtract`.
+  setTest(1,
+    (aUpper + aLower + bUpper) - (aLower + bUpper),
+    aUpper)
+  setTest(1,
+    (-aUpper + bUpper - aLower) - (-aUpper + bUpper),
+    bUpper - aUpper - aLower + aUpper - bUpper,
+    (bUpper + aUpper) - (aUpper + aLower + bUpper))
+  setTest(1,
+    -(-aUpper - aLower + bUpper) - (aUpper - aUpper + bUpper),
+    -bUpper + aUpper + aLower - aUpper + aUpper - bUpper,
+    (-bUpper - bUpper + aUpper) + (aUpper + aLower - aUpper))
+
+  setTest(1,
+    aUpper + aLower - aLower, aUpper)
+  setTest(1,
+    aUpper + aLower + aUpper + bUpper + bLower - aLower,
+    aUpper + aUpper + bUpper + bLower)
+  setTest(1,
+    aUpper + (aLower + aUpper + bUpper) + bLower - aLower - (aUpper + bUpper),
+    aUpper + bLower)
+  setTest(1,
+    aUpper + aLower - aUpper - aLower - bUpper,
+    -bUpper)
+  setTest(1,
+    aUpper + aLower - aUpper - aLower,
+    0)
+
+  // Canonicalize the expressions composed of `Multiply` and `Divide`.
+  setTest(1,
+    aUpper * bLower / bLower,
+    aUpper)
+  setTest(1,
+    aUpper * bLower * bUpper * bUpper * aLower / bUpper,
+    aUpper * bLower * bUpper * aLower)
+  setTest(1,
+    (aUpper + bUpper) * (bLower * bUpper) * bUpper * aLower / bUpper / (bLower * bUpper),
+    (aUpper + bUpper) * aLower)
+  setTest(1,
+    aUpper * bLower * bUpper / aUpper / bLower / bUpper,
+    aUpper * bLower / aUpper * bUpper / bLower / bUpper,
+    Literal(1))
+  setTest(1,
+    aUpper * bLower * bUpper / aUpper / bLower / bUpper / (aUpper + aLower),
+    aUpper / aUpper * bLower * bUpper / bLower / (aUpper + aLower) / bUpper,
+    Literal(1) / (aUpper + aLower))
+
   test("add to / remove from set") {
     val initialSet = ExpressionSet(aUpper + 1 :: Nil)
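What the setTest(1, ...) cases above assert, in isolation: with this patch, arithmetic
forms that differ only in reorderable add/subtract or multiply/divide terms canonicalize
to the same expression, so an ExpressionSet keeps a single entry for all of them. A small
sketch, with the attribute definitions copied from the top of ExpressionSetSuite (note
the follow-up patch below moves this simplification into the optimizer and removes these
tests again):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    // Attributes as defined in ExpressionSetSuite.
    val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
    val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))

    // (A + A + B) - A and B + A now canonicalize identically, so the
    // set deduplicates them down to one element.
    val set = ExpressionSet(Seq(
      Subtract(Add(Add(aUpper, aUpper), bUpper), aUpper),
      Add(bUpper, aUpper)))
    assert(set.size == 1)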
From 13e236d42bb3668adf312021b17034c3b1d44161 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 16 Mar 2017 07:35:24 +0000
Subject: [PATCH 2/2] Move the change to optimizer.

---
 .../catalyst/expressions/Canonicalize.scala   | 117 +----------
 .../sql/catalyst/optimizer/Optimizer.scala    |   1 +
 .../sql/catalyst/optimizer/expressions.scala  | 181 ++++++++++++++++--
 .../expressions/ExpressionSetSuite.scala      |  47 -----
 .../SimplifyAssociativeOperatorSuite.scala    |  74 +++++++
 .../spark/sql/ColumnExpressionSuite.scala     |   7 +-
 6 files changed, 249 insertions(+), 178 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyAssociativeOperatorSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
index 02f9ff5781f49..71531221881a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.types.DataType
-
 /**
  * Rewrites an expression using rules that are guaranteed preserve the result while attempting
  * to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization
@@ -45,8 +43,11 @@
     case _ => e
   }
 
-  /** Collects adjacent operations. The operations can be non-commutative. */
-  private def gatherAdjacent(
+  /**
+   * Collects adjacent operations. The operations can be non-commutative.
+   * This is not private because the optimizer also uses it.
+   */
+  def gatherAdjacent(
       e: Expression,
       f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match {
     case c if f.isDefinedAt(c) => f(c).flatMap(gatherAdjacent(_, f))
@@ -59,86 +60,10 @@
   private def orderCommutative(
       e: Expression,
       f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] =
     gatherAdjacent(e, f).sortBy(_.hashCode())
 
-  /** Rearrange expressions that are commutative, associative, or semantically equal. */
+  /** Rearrange expressions that are commutative or associative. */
   private def expressionReorder(e: Expression): Expression = e match {
-    case UnaryMinus(UnaryMinus(c)) => c
-
-    // If the expression is composed of `Add` and `Subtract`, we rearrange it by extracting all
-    // sub-expressions like:
-    // a + b => a, b
-    // a - b => a, -b
-    // -(a + b) => -a, -b
-    // -(a - b) => -a, b
-    // Then we combine those sub-expressions as follows:
-    // 1. Remove the pairs of sub-expressions like (b, -b).
-    // 2. Concatenate the remaining sub-expressions with `Add`.
-    case a: Add =>
-      // Extract sub-expressions.
-      val (positiveExprs, negativeExprs) = gatherAdjacent(a, {
-        case Add(l, r) => Seq(l, r)
-        case Subtract(l, r) => Seq(l, UnaryMinus(r))
-        case UnaryMinus(Add(l, r)) => Seq(UnaryMinus(l), UnaryMinus(r))
-        case UnaryMinus(Subtract(l, r)) => Seq(UnaryMinus(l), r)
-      }).map { e =>
-        e.transform { case UnaryMinus(UnaryMinus(c)) => c }
-      }.filter {
-        case Literal(0, _) => false
-        case UnaryMinus(Literal(0, _)) => false
-        case _ => true
-      }.partition(!_.isInstanceOf[UnaryMinus])
-
-      // Remove the pairs of sub-expressions like (b, -b).
-      val (newLeftExprs, newRightExprs) = filterOutExprs(positiveExprs,
-        negativeExprs.map(_.asInstanceOf[UnaryMinus].child))
-
-      val finalExprs = (newLeftExprs ++ newRightExprs.map(UnaryMinus(_))).sortBy(_.hashCode())
-      if (finalExprs.isEmpty) {
-        Literal(0, a.dataType)
-      } else {
-        finalExprs.reduce(Add)
-      }
-
-    case Subtract(sl, sr) => expressionReorder(Add(sl, UnaryMinus(sr)))
-
-    // If the expression is composed of `Multiply` and `Divide`, we rearrange it by extracting all
-    // sub-expressions like:
-    // a * b => a, b
-    // a / b => a, 1 / b
-    // 1 / (a * b) => 1 / a, 1 / b
-    // 1 / (a / b) => 1 / a, b
-    // Then we combine those sub-expressions as follows:
-    // 1. Remove the pairs of sub-expressions like (b, 1 / b).
-    // 2. Concatenate the remaining sub-expressions with `Multiply` and `Divide`.
-    case m: Multiply =>
-      // Extract sub-expressions.
-      val (multiplyExprs, reciprocalExprs) = gatherAdjacent(m, {
-        case Multiply(l, r) => Seq(l, r)
-        case Divide(l, r) => Seq(l, UnaryReciprocal(r))
-        case UnaryReciprocal(Multiply(l, r)) => Seq(UnaryReciprocal(l), UnaryReciprocal(r))
-        case UnaryReciprocal(Divide(l, r)) => Seq(UnaryReciprocal(l), r)
-      }).map { e =>
-        e.transform { case UnaryReciprocal(UnaryReciprocal(c)) => c }
-      }.filter {
-        case Literal(1, _) => false
-        case UnaryReciprocal(Literal(1, _)) => false
-        case _ => true
-      }.partition(!_.isInstanceOf[UnaryReciprocal])
-
-      // Remove the pairs of sub-expressions like (b, 1 / b).
-      val (newLeftExprs, newRightExprs) = filterOutExprs(multiplyExprs,
-        reciprocalExprs.map(_.asInstanceOf[UnaryReciprocal].child))
-
-      val finalExprs = (newLeftExprs ++ newRightExprs.map(UnaryReciprocal(_))).sortBy(_.hashCode())
-      if (finalExprs.isEmpty) {
-        Literal(1, m.dataType)
-      } else {
-        finalExprs.map {
-          case u: UnaryReciprocal => Divide(Literal(1, u.dataType), u.child)
-          case other => other
-        }.reduce(Multiply)
-      }
-
-    case Divide(dl, dr) => expressionReorder(Multiply(dl, UnaryReciprocal(dr)))
+    case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
+    case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)
 
     case o: Or =>
       orderCommutative(o, { case Or(l, r) if l.deterministic && r.deterministic => Seq(l, r) })
@@ -165,30 +90,4 @@
 
     case _ => e
   }
-
-  /** Finds the expressions existing in both sets and drops them from both. */
-  private def filterOutExprs(
-      leftExprs: Seq[Expression],
-      rightExprs: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
-    var newLeftExprs = leftExprs
-    val foundIndexes = rightExprs.zipWithIndex.map { case (r, rIndex) =>
-      val found = newLeftExprs.indexWhere(_.semanticEquals(r))
-      if (found >= 0) {
-        newLeftExprs = newLeftExprs.slice(0, found) ++
-          newLeftExprs.slice(found + 1, newLeftExprs.length)
-      }
-      (found, rIndex)
-    }
-    val dropRightIndexes = foundIndexes.filter(_._1 >= 0).unzip._2
-    val newRightExprs = rightExprs.zipWithIndex.filterNot { case (r, index) =>
-      dropRightIndexes.contains(index)
-    }.unzip._1
-    (newLeftExprs, newRightExprs)
-  }
-}
-
-/** A private [[UnaryExpression]] only used in expression canonicalization. */
-private[expressions] case class UnaryReciprocal(child: Expression)
-  extends UnaryExpression with Unevaluable {
-  override def dataType: DataType = child.dataType
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c8ed4190a13ad..6ae07aeb20576 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -102,6 +102,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
       OptimizeIn(conf),
       ConstantFolding,
       ReorderAssociativeOperator,
+      SimplifyAssociativeOperator,
       LikeSimplification,
       BooleanSimplification,
       SimplifyConditionals,
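One subtlety before the rule itself: SimplifyAssociativeOperator, registered above and
defined in the expressions.scala hunk below, never rewrites expressions inside an
Aggregate. The hazard is easiest to see on a plan where the same arithmetic appears as
both a grouping and an output expression. A sketch using the Catalyst test DSL (the same
helpers the new test suite uses):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val relation = LocalRelation('a.int, 'b.int)

    // 'a + 'a - 'a is both the grouping expression and the aggregate output.
    // If the rule rewrote only the output to plain 'a, the output could no
    // longer be resolved against the grouping expressions, so Aggregate
    // plans are left untouched ("nested expression with aggregate operator"
    // below asserts exactly this).
    val plan = relation.groupBy('a + 'a - 'a)(('a + 'a - 'a).as("col"))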
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 21d1cd5932620..e0f78a1d1b8b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -52,10 +52,150 @@ object ConstantFolding extends Rule[LogicalPlan] {
   }
 }
 
+/** A private [[UnaryExpression]] only used in expression optimization. */
+private[optimizer] case class UnaryReciprocal(child: Expression)
+  extends UnaryExpression with Unevaluable {
+  override def dataType: DataType = child.dataType
+}
+
 /**
- * Reorder associative integral-type operators and fold all constants into one.
+ * If an expression is composed of sub-expressions that can cancel each other out, e.g.,
+ * the pair (b, -b) in a chain of `Add`, this rule simplifies the expression by removing
+ * such pairs of sub-expressions.
  */
+object SimplifyAssociativeOperator extends Rule[LogicalPlan] {
+  /** Finds the expressions existing in both sets and drops them from both. */
+  private def filterOutExprs(
+      leftExprs: Seq[Expression],
+      rightExprs: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+    var newLeftExprs = leftExprs
+    val foundIndexes = rightExprs.zipWithIndex.map { case (r, rIndex) =>
+      val (rExpr, rDataType) = r match {
+        case Cast(wrappedR, dt, _) => (wrappedR, dt)
+        case _ => (r, r.dataType)
+      }
+      val found = newLeftExprs.indexWhere {
+        case Cast(wrapped, dt, _) if dt.sameType(rDataType) => wrapped.semanticEquals(rExpr)
+        case other if other.dataType.sameType(rDataType) => other.semanticEquals(rExpr)
+        case _ => false
+      }
+      if (found >= 0) {
+        newLeftExprs = newLeftExprs.slice(0, found) ++
+          newLeftExprs.slice(found + 1, newLeftExprs.length)
+      }
+      (found, rIndex)
+    }
+    val dropRightIndexes = foundIndexes.filter(_._1 >= 0).unzip._2
+    val newRightExprs = rightExprs.zipWithIndex.filterNot { case (r, index) =>
+      dropRightIndexes.contains(index)
+    }.unzip._1
+    (newLeftExprs, newRightExprs)
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    // We skip `Aggregate` when simplifying associative expressions, because this rule
+    // traverses up the expression tree and replaces sub-expressions with simplified ones.
+    // If an aggregate expression that also appears in the grouping expressions were
+    // simplified, the optimized output expression could no longer be matched against the
+    // grouping expressions.
+    case q: LogicalPlan if !q.isInstanceOf[Aggregate] =>
+      q transformExpressionsUp {
+        case e => reorderExpr(e)
+      }
+  }
+
+  private def reorderExpr(expr: Expression): Expression = expr match {
+    case u @ UnaryMinus(UnaryMinus(c)) if u.deterministic => c
+
+    // If the expression is composed of `Add` and `Subtract`, we rearrange it by extracting all
+    // sub-expressions like:
+    // a + b => a, b
+    // a - b => a, -b
+    // -(a + b) => -a, -b
+    // -(a - b) => -a, b
+    // Then we combine those sub-expressions as follows:
+    // 1. Remove the pairs of sub-expressions like (b, -b).
+    // 2. Concatenate the remaining sub-expressions with `Add`.
+    case a: Add if a.deterministic =>
+      // Extract sub-expressions.
+      val (positiveExprs, negativeExprs) = Canonicalize.gatherAdjacent(a, {
+        case Add(l, r) => Seq(l, r)
+        case Subtract(l, r) => Seq(l, UnaryMinus(r))
+        case UnaryMinus(Add(l, r)) => Seq(UnaryMinus(l), UnaryMinus(r))
+        case UnaryMinus(Subtract(l, r)) => Seq(UnaryMinus(l), r)
+      }).map { e =>
+        e.transform { case UnaryMinus(UnaryMinus(c)) => c }
+      }.filter {
+        // Remove any +0, -0.
+        case Literal(0, _) => false
+        case UnaryMinus(Literal(0, _)) => false
+        case Cast(Literal(0, _), _, _) => false
+        case UnaryMinus(Cast(Literal(0, _), _, _)) => false
+        case _ => true
+      }.partition(!_.isInstanceOf[UnaryMinus])
+
+      // Remove the pairs of sub-expressions like (b, -b).
+      val (newLeftExprs, newRightExprs) = filterOutExprs(positiveExprs,
+        negativeExprs.map(_.asInstanceOf[UnaryMinus].child))
+
+      val finalExprs = (newLeftExprs ++ newRightExprs.map(UnaryMinus(_)))
+      if (finalExprs.isEmpty) {
+        Cast(Literal(0), a.dataType)
+      } else {
+        finalExprs.reduce(Add)
+      }
+
+    case s @ Subtract(sl, sr) if s.deterministic => reorderExpr(Add(sl, UnaryMinus(sr)))
+
+    // If the expression is composed of `Multiply` and `Divide`, we rearrange it by extracting
+    // all sub-expressions like:
+    // a * b => a, b
+    // a / b => a, 1 / b
+    // 1 / (a * b) => 1 / a, 1 / b
+    // 1 / (a / b) => 1 / a, b
+    // Then we combine those sub-expressions as follows:
+    // 1. Remove the pairs of sub-expressions like (b, 1 / b).
+    // 2. Concatenate the remaining sub-expressions with `Multiply` and `Divide`.
+    case m: Multiply if m.deterministic =>
+      // Extract sub-expressions.
+      val (multiplyExprs, reciprocalExprs) = Canonicalize.gatherAdjacent(m, {
+        case Multiply(l, r) => Seq(l, r)
+        case Divide(l, r) => Seq(l, UnaryReciprocal(r))
+        case UnaryReciprocal(Multiply(l, r)) => Seq(UnaryReciprocal(l), UnaryReciprocal(r))
+        case UnaryReciprocal(Divide(l, r)) => Seq(UnaryReciprocal(l), r)
+      }).map { e =>
+        e.transform { case UnaryReciprocal(UnaryReciprocal(c)) => c }
+      }.filter {
+        // Remove any *1, /1.
+        case Literal(1, _) => false
+        case UnaryReciprocal(Literal(1, _)) => false
+        case Cast(Literal(1, _), _, _) => false
+        case UnaryReciprocal(Cast(Literal(1, _), _, _)) => false
+        case _ => true
+      }.partition(!_.isInstanceOf[UnaryReciprocal])
+
+      // Remove the pairs of sub-expressions like (b, 1 / b).
+      val (newLeftExprs, newRightExprs) = filterOutExprs(multiplyExprs,
+        reciprocalExprs.map(_.asInstanceOf[UnaryReciprocal].child))
+
+      val finalExprs = (newLeftExprs ++ newRightExprs.map(UnaryReciprocal(_)))
+      if (finalExprs.isEmpty) {
+        Cast(Literal(1), m.dataType)
+      } else {
+        // A leading reciprocal has no left operand to divide, so seed the fold with 1 / x.
+        val first = finalExprs.head match {
+          case u: UnaryReciprocal => Divide(Cast(Literal(1), u.dataType), u.child)
+          case other => other
+        }
+        finalExprs.tail.foldLeft(first) { (resultExpr, expr) =>
+          expr match {
+            case u: UnaryReciprocal => Divide(resultExpr, u.child)
+            case other => Multiply(resultExpr, other)
+          }
+        }
+      }
+
+    case d @ Divide(dl, dr) if d.deterministic => reorderExpr(Multiply(dl, UnaryReciprocal(dr)))
+
+    case _ => expr
+  }
+}
+
+/** Reorder associative integral-type operators and fold all constants into one. */
 object ReorderAssociativeOperator extends Rule[LogicalPlan] {
   private def flattenAdd(
       expression: Expression,
@@ -86,29 +226,28 @@
     // grouping expressions.
     val groupingExpressionSet = collectGroupingExpressions(q)
     q transformExpressionsDown {
-      case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
-        val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
-        if (foldables.size > 1) {
-          val foldableExpr = foldables.reduce((x, y) => Add(x, y))
-          val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
-          if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c)
-        } else {
-          a
-        }
-      case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
-        val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable)
-        if (foldables.size > 1) {
-          val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
-          val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
-          if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c)
-        } else {
-          m
-        }
-      }
+        case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
+          val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
+          if (foldables.size > 1) {
+            val foldableExpr = foldables.reduce((x, y) => Add(x, y))
+            val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
+            if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y)), c)
+          } else {
+            a
+          }
+        case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
+          val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable)
+          if (foldables.size > 1) {
+            val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
+            val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
+            if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y)), c)
+          } else {
+            m
+          }
+      }
   }
 }
 
-
 /**
  * Optimize IN predicates:
  * 1. Removes literal repetitions.
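The fold at the end of the Multiply case is what turns the flat factor list back into a
left-deep operator chain, emitting Divide for reciprocal-tagged factors. A standalone
sketch, again with illustrative plain-Scala stand-ins rather than Catalyst classes:

    sealed trait Factor
    case class Sym(name: String) extends Factor
    case class Recip(child: Factor) extends Factor
    case class Mul(l: Factor, r: Factor) extends Factor
    case class Div(l: Factor, r: Factor) extends Factor

    def rebuild(first: Factor, rest: Seq[Factor]): Factor =
      rest.foldLeft(first) {
        case (acc, Recip(c)) => Div(acc, c) // acc, 1/c  =>  acc / c
        case (acc, f) => Mul(acc, f)        // acc, f    =>  acc * f
      }

    // The factor list {a, b, 1/c} reassembles as (a * b) / c.
    assert(rebuild(Sym("a"), Seq(Sym("b"), Recip(Sym("c"))))
      == Div(Mul(Sym("a"), Sym("b")), Sym("c")))

Folding left-deep is what makes a * b / c come back out as (a * b) / c, rather than as a
chain of multiplications by explicit 1 / x reciprocals, which is what the first patch
produced.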
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
index b9359ba95b9ce..d617ad540d5ff 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
@@ -198,53 +198,6 @@
     Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (aUpper > 10 && bUpper > 10),
     Rand(1L) > aUpper || (aUpper > bUpper && aUpper <= Rand(1L)) || (aUpper > 10 && bUpper > 10))
 
-  // Canonicalize the expressions composed of `Add` and `Subtract`.
-  setTest(1,
-    (aUpper + aLower + bUpper) - (aLower + bUpper),
-    aUpper)
-  setTest(1,
-    (-aUpper + bUpper - aLower) - (-aUpper + bUpper),
-    bUpper - aUpper - aLower + aUpper - bUpper,
-    (bUpper + aUpper) - (aUpper + aLower + bUpper))
-  setTest(1,
-    -(-aUpper - aLower + bUpper) - (aUpper - aUpper + bUpper),
-    -bUpper + aUpper + aLower - aUpper + aUpper - bUpper,
-    (-bUpper - bUpper + aUpper) + (aUpper + aLower - aUpper))
-
-  setTest(1,
-    aUpper + aLower - aLower, aUpper)
-  setTest(1,
-    aUpper + aLower + aUpper + bUpper + bLower - aLower,
-    aUpper + aUpper + bUpper + bLower)
-  setTest(1,
-    aUpper + (aLower + aUpper + bUpper) + bLower - aLower - (aUpper + bUpper),
-    aUpper + bLower)
-  setTest(1,
-    aUpper + aLower - aUpper - aLower - bUpper,
-    -bUpper)
-  setTest(1,
-    aUpper + aLower - aUpper - aLower,
-    0)
-
-  // Canonicalize the expressions composed of `Multiply` and `Divide`.
-  setTest(1,
-    aUpper * bLower / bLower,
-    aUpper)
-  setTest(1,
-    aUpper * bLower * bUpper * bUpper * aLower / bUpper,
-    aUpper * bLower * bUpper * aLower)
-  setTest(1,
-    (aUpper + bUpper) * (bLower * bUpper) * bUpper * aLower / bUpper / (bLower * bUpper),
-    (aUpper + bUpper) * aLower)
-  setTest(1,
-    aUpper * bLower * bUpper / aUpper / bLower / bUpper,
-    aUpper * bLower / aUpper * bUpper / bLower / bUpper,
-    Literal(1))
-  setTest(1,
-    aUpper * bLower * bUpper / aUpper / bLower / bUpper / (aUpper + aLower),
-    aUpper / aUpper * bLower * bUpper / bLower / (aUpper + aLower) / bUpper,
-    Literal(1) / (aUpper + aLower))
-
   test("add to / remove from set") {
     val initialSet = ExpressionSet(aUpper + 1 :: Nil)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyAssociativeOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyAssociativeOperatorSuite.scala
new file mode 100644
index 0000000000000..db02a5446c3f5
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyAssociativeOperatorSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.DoubleType + +class SimplifyAssociativeOperatorSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("SimplifyAssociativeOperator", FixedPoint(100), + SimplifyAssociativeOperator) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.double, 'c.int) + + test("Reorder associative operators") { + val originalQuery = + testRelation + .select( + ((Literal(3) + ((Literal(1) + 'a) - 3)) - 1).as("a1"), + ((Literal(2.0) + 'b) * 2.0 / 2.0 / 3.0 * 3.0).as("b1"), + (('c + 1) / ('b + 2) * ('b + 2 - 'a + 'a)).as("c1"), + Rand(0) * 1 * 2 * 3 * 4) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + ('a).as("a1"), + (Literal(2.0) + 'b).as("b1"), + (Cast('c + 1, DoubleType)).as("c1"), + Rand(0) * 1 * 2 * 3 * 4) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("nested expression with aggregate operator") { + val originalQuery = + testRelation.as("t1") + .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr)) + .groupBy("t1.a".attr + "t1.a".attr - "t1.a".attr, "t2.a".attr + "t2.a".attr - "t2.a".attr)( + ("t1.a".attr + "t1.a".attr - "t1.a".attr).as("col")) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = originalQuery.analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b0f398dab7455..e52fa930927d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -666,8 +666,13 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { // Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)] // Project [key, Rand 5 AS rand] // LogicalRDD [key, value] + // + // SPARK-19902: Above query plan is further optimized to ... + // Project [2.0 AS (rand1 - rand2)] + // LogicalRDD [key, value] + // Because (rand + 1 AS rand1) - (rand - 1 AS rand2) is actually 2. val dfWithThreeProjects = dfWithTwoProjects.select($"rand1" - $"rand2") - checkNumProjects(dfWithThreeProjects, 2) + checkNumProjects(dfWithThreeProjects, 1) dfWithThreeProjects.collect().foreach { row => assert(row.getDouble(0) === 2.0 +- 0.0001) }