diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
index b58a5273041e4..ae1f6006135bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.rules._
-
 /**
  * Rewrites an expression using rules that are guaranteed preserve the result while attempting
  * to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization
@@ -30,26 +28,23 @@ import org.apache.spark.sql.catalyst.rules._
  *  - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped.
  *  - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered
  *    by `hashCode`.
-* - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
+ *  - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
  *  - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
  */
-object Canonicalize extends RuleExecutor[Expression] {
-  override protected def batches: Seq[Batch] =
-    Batch(
-      "Expression Canonicalization", FixedPoint(100),
-      IgnoreNamesTypes,
-      Reorder) :: Nil
+object Canonicalize extends {
+  def execute(e: Expression): Expression = {
+    expressionReorder(ignoreNamesTypes(e))
+  }
 
   /** Remove names and nullability from types. */
-  protected object IgnoreNamesTypes extends Rule[Expression] {
-    override def apply(e: Expression): Expression = e transformUp {
-      case a: AttributeReference =>
-        AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId)
-    }
+  private def ignoreNamesTypes(e: Expression): Expression = e match {
+    case a: AttributeReference =>
+      AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId)
+    case _ => e
   }
 
   /** Collects adjacent commutative operations. */
-  protected def gatherCommutative(
+  private def gatherCommutative(
       e: Expression,
       f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match {
     case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f))
@@ -57,25 +52,25 @@ object Canonicalize extends RuleExecutor[Expression] {
   }
 
   /** Orders a set of commutative operations by their hash code. */
-  protected def orderCommutative(
+  private def orderCommutative(
       e: Expression,
       f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] =
     gatherCommutative(e, f).sortBy(_.hashCode())
 
   /** Rearrange expressions that are commutative or associative. */
-  protected object Reorder extends Rule[Expression] {
-    override def apply(e: Expression): Expression = e transformUp {
-      case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
-      case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)
+  private def expressionReorder(e: Expression): Expression = e match {
+    case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
+    case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)
+
+    case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
+    case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l)
 
-      case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
-      case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l)
+    case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l)
+    case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l)
 
-      case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l)
-      case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l)
+    case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l)
+    case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l)
 
-      case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l)
-      case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l)
-    }
+    case _ => e
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 692c16092fe3f..16a1b2aee2730 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -152,7 +152,10 @@ abstract class Expression extends TreeNode[Expression] {
    * `deterministic` expressions where `this.canonicalized == other.canonicalized` will always
    * evaluate to the same result.
    */
-  lazy val canonicalized: Expression = Canonicalize.execute(this)
+  lazy val canonicalized: Expression = {
+    val canonicalizedChildren = children.map(_.canonicalized)
+    Canonicalize.execute(withNewChildren(canonicalizedChildren))
+  }
 
   /**
    * Returns true when two expressions will always compute the same result, even if they differ
@@ -161,7 +164,7 @@ abstract class Expression extends TreeNode[Expression] {
    * See [[Canonicalize]] for more details.
    */
   def semanticEquals(other: Expression): Boolean =
-      deterministic && other.deterministic && canonicalized == other.canonicalized
+    deterministic && other.deterministic && canonicalized == other.canonicalized
 
   /**
    * Returns a `hashCode` for the calculation performed by this expression. Unlike the standard
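For reviewers, here is a minimal, hypothetical sketch (not part of the patch) of how the reworked canonicalization can be exercised from a test or a spark-shell session. It only uses existing Catalyst classes (`AttributeReference`, `Add`, `GreaterThan`, `LessThan`) and the `semanticEquals` method touched above; the wrapper object `CanonicalizeSketch` and the chosen attribute names are illustrative assumptions.

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// Hypothetical driver, not part of the patch: exercises the bottom-up
// canonicalization introduced by Expression.canonicalized above.
object CanonicalizeSketch {
  def main(args: Array[String]): Unit = {
    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()

    // Add is commutative: both trees canonicalize to the same child order
    // (children are sorted by hashCode), so they compare as semantically equal.
    val sum1 = Add(a, b)
    val sum2 = Add(b, a)
    assert(sum1.semanticEquals(sum2))

    // a > b and b < a canonicalize to the same comparison, because
    // GreaterThan/LessThan are flipped when the left child's hashCode is larger.
    val cmp1 = GreaterThan(a, b)
    val cmp2 = LessThan(b, a)
    assert(cmp1.semanticEquals(cmp2))

    // Names are stripped during canonicalization but distinct ExprIds are kept,
    // so two different attributes do not become semantically equal.
    assert(!a.semanticEquals(b))
  }
}
```

The design point of the change is visible here: `canonicalized` now canonicalizes the children before calling `Canonicalize.execute` on the parent, so a single local rewrite per node in one bottom-up pass replaces the old `FixedPoint(100)` `RuleExecutor` batch that repeatedly traversed the whole tree.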