From 2a4f4de5f27cda3e5ad1bb899e45f06b6560ab4d Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Tue, 23 Feb 2016 16:46:31 -0800
Subject: [PATCH 1/2] [SPARK-13092][SQL] Add ExpressionSet for constraint
 tracking

---
 .../spark/sql/catalyst/dsl/package.scala      |  2 +-
 .../catalyst/expressions/Canonicalize.scala   | 81 +++++++++++++++++
 .../sql/catalyst/expressions/Expression.scala | 56 ++++--------
 .../catalyst/expressions/ExpressionSet.scala  | 87 +++++++++++++++++++
 .../spark/sql/catalyst/plans/QueryPlan.scala  | 12 +--
 .../expressions/ExpressionSetSuite.scala      | 85 ++++++++++++++++++
 .../org/apache/spark/sql/SQLQuerySuite.scala  |  3 +-
 7 files changed, 281 insertions(+), 45 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 1a2ec7ed931ea..a12f7396fe819 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -235,7 +235,7 @@ package object dsl {
 
   implicit class DslAttribute(a: AttributeReference) {
     def notNull: AttributeReference = a.withNullability(false)
-    def nullable: AttributeReference = a.withNullability(true)
+    def canBeNull: AttributeReference = a.withNullability(true)
     def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
new file mode 100644
index 0000000000000..b58a5273041e4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.rules._
+
+/**
+ * Rewrites an expression using rules that are guaranteed to preserve the result while attempting
+ * to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization
+ * will always return the same answer given the same input (i.e. false positives should not be
+ * possible). However, it is possible that two canonical expressions that are not equal will in fact
+ * return the same answer given any input (i.e. false negatives are possible).
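+ *
+ * As a quick sketch of the intent (illustrative only; `a` and `b` stand for two distinct
+ * integer attributes):
+ * {{{
+ *   Canonicalize.execute(b + a) == Canonicalize.execute(a + b)   // children reordered
+ *   Canonicalize.execute(a > b) == Canonicalize.execute(b < a)   // comparison reversed
+ * }}}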
+ *
+ * The following rules are applied:
+ *  - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped.
+ *  - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered
+ *    by `hashCode`.
+ *  - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
+ *  - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
+ */
+object Canonicalize extends RuleExecutor[Expression] {
+  override protected def batches: Seq[Batch] =
+    Batch(
+      "Expression Canonicalization", FixedPoint(100),
+      IgnoreNamesTypes,
+      Reorder) :: Nil
+
+  /** Remove names and nullability from types. */
+  protected object IgnoreNamesTypes extends Rule[Expression] {
+    override def apply(e: Expression): Expression = e transformUp {
+      case a: AttributeReference =>
+        AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId)
+    }
+  }
+
+  /** Collects adjacent commutative operations. */
+  protected def gatherCommutative(
+      e: Expression,
+      f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match {
+    case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f))
+    case other => other :: Nil
+  }
+
+  /** Orders a set of commutative operations by their hash code. */
+  protected def orderCommutative(
+      e: Expression,
+      f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] =
+    gatherCommutative(e, f).sortBy(_.hashCode())
+
+  /** Rearrange expressions that are commutative or associative. */
+  protected object Reorder extends Rule[Expression] {
+    override def apply(e: Expression): Expression = e transformUp {
+      case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
+      case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)
+
+      case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
+      case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l)
+
+      case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l)
+      case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l)
+
+      case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l)
+      case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l)
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 119496c7eee5e..692c16092fe3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -144,49 +144,32 @@ abstract class Expression extends TreeNode[Expression] {
    */
   def childrenResolved: Boolean = children.forall(_.resolved)
 
+  /**
+   * Returns an expression where a best effort attempt has been made to transform `this` in a way
+   * that preserves the result but removes cosmetic variations (case sensitivity, ordering for
+   * commutative operations, etc.). See [[Canonicalize]] for more details.
+   *
+   * `deterministic` expressions where `this.canonicalized == other.canonicalized` will always
+   * evaluate to the same result.
+   */
+  lazy val canonicalized: Expression = Canonicalize.execute(this)
+
   /**
    * Returns true when two expressions will always compute the same result, even if they differ
    * cosmetically (i.e. capitalization of names in attributes may be different).
+   *
+   * See [[Canonicalize]] for more details.
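+   *
+   * A sketch of the intended behavior (illustrative; assuming an integer attribute `a`):
+   * {{{
+   *   (a + 1) semanticEquals (1 + a)   // true: commutative ordering is canonicalized away
+   *   (a + 1) semanticEquals (a + 2)   // false: genuinely different results
+   * }}}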
    */
-  def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
-    def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = {
-      elements1.length == elements2.length && elements1.zip(elements2).forall {
-        case (e1: Expression, e2: Expression) => e1 semanticEquals e2
-        case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2
-        case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq)
-        case (i1, i2) => i1 == i2
-      }
-    }
-    // Non-deterministic expressions cannot be semantic equal
-    if (!deterministic || !other.deterministic) return false
-    val elements1 = this.productIterator.toSeq
-    val elements2 = other.asInstanceOf[Product].productIterator.toSeq
-    checkSemantic(elements1, elements2)
-  }
+  def semanticEquals(other: Expression): Boolean =
+    deterministic && other.deterministic && canonicalized == other.canonicalized
 
   /**
-   * Returns the hash for this expression. Expressions that compute the same result, even if
-   * they differ cosmetically should return the same hash.
+   * Returns a `hashCode` for the calculation performed by this expression. Unlike the standard
+   * `hashCode`, an attempt has been made to eliminate cosmetic differences.
+   *
+   * See [[Canonicalize]] for more details.
    */
-  def semanticHash() : Int = {
-    def computeHash(e: Seq[Any]): Int = {
-      // See http://stackoverflow.com/questions/113511/hash-code-implementation
-      var hash: Int = 17
-      e.foreach(i => {
-        val h: Int = i match {
-          case e: Expression => e.semanticHash()
-          case Some(e: Expression) => e.semanticHash()
-          case t: Traversable[_] => computeHash(t.toSeq)
-          case null => 0
-          case other => other.hashCode()
-        }
-        hash = hash * 37 + h
-      })
-      hash
-    }
-
-    computeHash(this.productIterator.toSeq)
-  }
+  def semanticHash(): Int = canonicalized.hashCode()
 
   /**
    * Checks the input data types, returns `TypeCheckResult.success` if it's valid,
@@ -369,7 +352,6 @@ abstract class UnaryExpression extends Expression {
   }
 }
 
-
 /**
  * An expression with two inputs and one output. The output is by default evaluated to null
  * if any input is evaluated to null.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
new file mode 100644
index 0000000000000..acea049adca3d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +object ExpressionSet { + /** Constructs a new [[ExpressionSet]] by applying [[Canonicalize]] to `expressions`. */ + def apply(expressions: TraversableOnce[Expression]): ExpressionSet = { + val set = new ExpressionSet() + expressions.foreach(set.add) + set + } +} + +/** + * A [[Set]] where membership is determined based on a canonical representation of an [[Expression]] + * (i.e. one that attempts to ignore cosmetic differences). See [[Canonicalize]] for more details. + * + * Internally this set uses the canonical representation, but keeps also track of the original + * expressions to ease debugging. Since different expressions can share the same canonical + * representation, this means that operations that extract expressions from this set are only + * guranteed to see at least one such expression. For example: + * + * {{{ + * val set = AttributeSet(a + 1, 1 + a) + * + * set.iterator => Iterator(a + 1) + * set.contains(a + 1) => true + * set.contains(1 + a) => true + * set.contains(a + 2) => false + * }}} + */ +class ExpressionSet protected( + protected val baseSet: mutable.Set[Expression] = new mutable.HashSet, + protected val originals: mutable.Buffer[Expression] = new ArrayBuffer) + extends Set[Expression] { + + protected def add(e: Expression): Unit = { + if (!baseSet.contains(e.canonicalized)) { + baseSet.add(e.canonicalized) + originals.append(e) + } + } + + override def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized) + + override def +(elem: Expression): ExpressionSet = { + val newSet = new ExpressionSet(baseSet.clone(), originals.clone()) + newSet.add(elem) + newSet + } + + override def -(elem: Expression): ExpressionSet = { + val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) + val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) + new ExpressionSet(newBaseSet, newOriginals) + } + + override def iterator: Iterator[Expression] = originals.iterator + + /** + * Returns a string containing both the post [[Canonicalize]] expressions and the original + * expressions in this set. + */ + def toDebugString: String = + s""" + |baseSet: ${baseSet.mkString(", ")} + |originals: ${originals.mkString(", ")} + """.stripMargin +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 5e7d144ae4107..a74b288cb22ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -63,17 +63,19 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } /** - * A sequence of expressions that describes the data property of the output rows of this - * operator. For example, if the output of this operator is column `a`, an example `constraints` - * can be `Set(a > 10, a < 20)`. + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For + * example, if this set contains the expression `a = 2` then that expression is guaranteed to + * evaluate to `true` for all rows produced. 
    */
-  lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints)
+  lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))
 
   /**
    * This method can be overridden by any child class of QueryPlan to specify a set of constraints
    * based on the given operator's constraint propagation logic. These constraints are then
    * canonicalized and filtered automatically to contain only those attributes that appear in the
-   * [[outputSet]]
+   * [[outputSet]].
+   *
+   * See [[Canonicalize]] for more details.
    */
   protected def validConstraints: Set[Expression] = Set.empty
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
new file mode 100644
index 0000000000000..e81c26de4d32f
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.IntegerType + +class ExpressionSetSuite extends SparkFunSuite { + + val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) + val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) + val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3)) + + val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2)) + val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2)) + + val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + + def setTest(size: Int, exprs: Expression*): Unit = { + test(s"expect $size: ${exprs.mkString(", ")}") { + val set = ExpressionSet(exprs) + if (set.size != size) { + fail(set.toDebugString) + } + } + } + + def setTestIgnore(size: Int, exprs: Expression*): Unit = + ignore(s"expect $size: ${exprs.mkString(", ")}") {} + + // Commutative + setTest(1, aUpper + 1, aLower + 1) + setTest(2, aUpper + 1, aLower + 2) + setTest(2, aUpper + 1, fakeA + 1) + setTest(2, aUpper + 1, bUpper + 1) + + setTest(1, aUpper + aLower, aLower + aUpper) + setTest(1, aUpper + bUpper, bUpper + aUpper) + setTest(1, + aUpper + bUpper + 3, + bUpper + 3 + aUpper, + bUpper + aUpper + 3, + Literal(3) + aUpper + bUpper) + setTest(1, + aUpper * bUpper * 3, + bUpper * 3 * aUpper, + bUpper * aUpper * 3, + Literal(3) * aUpper * bUpper) + setTest(1, aUpper === bUpper, bUpper === aUpper) + + setTest(1, aUpper + 1 === bUpper, bUpper === Literal(1) + aUpper) + + + // Not commutative + setTest(2, aUpper - bUpper, bUpper - aUpper) + + // Reversable + setTest(1, aUpper > bUpper, bUpper < aUpper) + setTest(1, aUpper >= bUpper, bUpper <= aUpper) + + test("add to / remove from set") { + val initialSet = ExpressionSet(aUpper + 1 :: Nil) + + assert((initialSet + (aUpper + 1)).size == 1) + assert((initialSet + (aUpper + 2)).size == 2) + assert((initialSet - (aUpper + 1)).size == 0) + assert((initialSet - (aUpper + 2)).size == 1) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1d9db27e097ee..13ff4a2c419e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1980,9 +1980,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { verifyCallCount( df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) - // Would be nice if semantic equals for `+` understood commutative verifyCallCount( - df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1) // Try disabling it via configuration. 
sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") From a255d957a400ecbef113e6f74e385786db0cd5a7 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 24 Feb 2016 11:09:57 -0800 Subject: [PATCH 2/2] more tests --- .../spark/sql/catalyst/expressions/ExpressionSetSuite.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index e81c26de4d32f..ce42e5784ccd2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -81,5 +81,9 @@ class ExpressionSetSuite extends SparkFunSuite { assert((initialSet + (aUpper + 2)).size == 2) assert((initialSet - (aUpper + 1)).size == 0) assert((initialSet - (aUpper + 2)).size == 1) + + assert((initialSet + (aLower + 1)).size == 1) + assert((initialSet - (aLower + 1)).size == 0) + } }
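
A minimal usage sketch of the new set semantics (illustrative only, not part of the patch;
it reuses the `aUpper`/`aLower` definitions from the suite above):

  import org.apache.spark.sql.catalyst.dsl.expressions._

  // "A" and "a" refer to the same attribute (same ExprId), so the two spellings below are
  // merely cosmetic variations of one calculation.
  val set = ExpressionSet(aUpper + 1 :: Literal(1) + aLower :: Nil)

  assert(set.size == 1)             // both inputs canonicalize identically
  assert(set.contains(aLower + 1))  // membership ignores case and operand order
  assert(!set.contains(aUpper + 2)) // a genuinely different calculation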