[SPARK-13092][SQL] Add ExpressionSet for constraint tracking #11338

Closed · wants to merge 2 commits
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -235,7 +235,7 @@ package object dsl {

   implicit class DslAttribute(a: AttributeReference) {
     def notNull: AttributeReference = a.withNullability(false)
-    def nullable: AttributeReference = a.withNullability(true)
+    def canBeNull: AttributeReference = a.withNullability(true)
     def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable)
   }
 }
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.rules._

/**
 * Rewrites an expression using rules that are guaranteed to preserve the result while attempting
 * to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization
 * will always return the same answer given the same input (i.e. false positives should not be
 * possible). However, it is possible that two canonical expressions that are not equal will in fact
 * return the same answer given any input (i.e. false negatives are possible).
 *
 * The following rules are applied:
 *  - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped.
 *  - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered
 *    by `hashCode`.
 *  - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
 *  - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
 */
object Canonicalize extends RuleExecutor[Expression] {
  override protected def batches: Seq[Batch] =
    Batch(
      "Expression Canonicalization", FixedPoint(100),
      IgnoreNamesTypes,
      Reorder) :: Nil

  /** Remove names and nullability from types. */
  protected object IgnoreNamesTypes extends Rule[Expression] {
    override def apply(e: Expression): Expression = e transformUp {
      case a: AttributeReference =>
        AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId)
    }
  }

  /** Collects adjacent commutative operations. */
  protected def gatherCommutative(
      e: Expression,
      f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match {
    case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f))
    case other => other :: Nil
  }

  /** Orders a set of commutative operations by their hash code. */
  protected def orderCommutative(
      e: Expression,
      f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] =
    gatherCommutative(e, f).sortBy(_.hashCode())

  /** Rearrange expressions that are commutative or associative. */
  protected object Reorder extends Rule[Expression] {
    override def apply(e: Expression): Expression = e transformUp {
      case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
      case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)

      case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
      case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l)

      case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l)
      case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l)

      case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l)
      case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l)
    }
  }
}
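
To make the rules above concrete, here is a minimal sketch (not part of the diff) of what canonicalization does to two cosmetic variants of the same computation; the attribute names and `ExprId`s are illustrative assumptions:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// Two references to the same column (same ExprId) that differ only in name.
val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))

// IgnoreNamesTypes strips the differing names; Reorder sorts the children
// of the commutative Add by hashCode. Both variants converge on one tree.
val c1 = Canonicalize.execute(Add(aUpper, Literal(1)))
val c2 = Canonicalize.execute(Add(Literal(1), aLower))
assert(c1 == c2)
```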
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -144,49 +144,32 @@ abstract class Expression extends TreeNode[Expression] {
    */
   def childrenResolved: Boolean = children.forall(_.resolved)
 
+  /**
+   * Returns an expression where a best effort attempt has been made to transform `this` in a way
+   * that preserves the result but removes cosmetic variations (case sensitivity, ordering for
+   * commutative operations, etc.) See [[Canonicalize]] for more details.
+   *
+   * `deterministic` expressions where `this.canonicalized == other.canonicalized` will always
+   * evaluate to the same result.
+   */
+  lazy val canonicalized: Expression = Canonicalize.execute(this)
+
   /**
    * Returns true when two expressions will always compute the same result, even if they differ
    * cosmetically (i.e. capitalization of names in attributes may be different).
+   *
+   * See [[Canonicalize]] for more details.
    */
-  def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
-    def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = {
-      elements1.length == elements2.length && elements1.zip(elements2).forall {
-        case (e1: Expression, e2: Expression) => e1 semanticEquals e2
-        case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2
-        case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq)
-        case (i1, i2) => i1 == i2
-      }
-    }
-    // Non-deterministic expressions cannot be semantic equal
-    if (!deterministic || !other.deterministic) return false
-    val elements1 = this.productIterator.toSeq
-    val elements2 = other.asInstanceOf[Product].productIterator.toSeq
-    checkSemantic(elements1, elements2)
-  }
+  def semanticEquals(other: Expression): Boolean =
+    deterministic && other.deterministic && canonicalized == other.canonicalized
 
   /**
-   * Returns the hash for this expression. Expressions that compute the same result, even if
-   * they differ cosmetically should return the same hash.
+   * Returns a `hashCode` for the calculation performed by this expression. Unlike the standard
+   * `hashCode`, an attempt has been made to eliminate cosmetic differences.
+   *
+   * See [[Canonicalize]] for more details.
    */
-  def semanticHash() : Int = {
-    def computeHash(e: Seq[Any]): Int = {
-      // See http://stackoverflow.com/questions/113511/hash-code-implementation
-      var hash: Int = 17
-      e.foreach(i => {
-        val h: Int = i match {
-          case e: Expression => e.semanticHash()
-          case Some(e: Expression) => e.semanticHash()
-          case t: Traversable[_] => computeHash(t.toSeq)
-          case null => 0
-          case other => other.hashCode()
-        }
-        hash = hash * 37 + h
-      })
-      hash
-    }
-
-    computeHash(this.productIterator.toSeq)
-  }
+  def semanticHash(): Int = canonicalized.hashCode()
 
   /**
    * Checks the input data types, returns `TypeCheckResult.success` if it's valid,
@@ -369,7 +352,6 @@ abstract class UnaryExpression extends Expression {
   }
 }
 
-
 /**
  * An expression with two inputs and one output. The output is by default evaluated to null
  * if any input is evaluated to null.
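
The payoff of delegating to `canonicalized`: equality and hashing now agree on commutative variants. A hedged sketch (the attribute setup is an assumption, mirroring the test suite below):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)(exprId = ExprId(1))

// `a + 1` and `1 + a` canonicalize to the same tree, so they are now
// semantically equal and hash identically; the old structural comparison
// treated them as different expressions.
val e1 = Add(a, Literal(1))
val e2 = Add(Literal(1), a)
assert(e1 semanticEquals e2)
assert(e1.semanticHash() == e2.semanticHash())
```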
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object ExpressionSet {
  /** Constructs a new [[ExpressionSet]] by applying [[Canonicalize]] to `expressions`. */
  def apply(expressions: TraversableOnce[Expression]): ExpressionSet = {
    val set = new ExpressionSet()
    expressions.foreach(set.add)
    set
  }
}

/**
 * A [[Set]] where membership is determined based on a canonical representation of an [[Expression]]
 * (i.e. one that attempts to ignore cosmetic differences). See [[Canonicalize]] for more details.
 *
 * Internally this set uses the canonical representation, but also keeps track of the original
 * expressions to ease debugging. Since different expressions can share the same canonical
 * representation, operations that extract expressions from this set are only guaranteed to see
 * at least one such expression. For example:
 *
 * {{{
 *   val set = ExpressionSet(a + 1, 1 + a)
 *
 *   set.iterator => Iterator(a + 1)
 *   set.contains(a + 1) => true
 *   set.contains(1 + a) => true
 *   set.contains(a + 2) => false
 * }}}
 */
class ExpressionSet protected(
    protected val baseSet: mutable.Set[Expression] = new mutable.HashSet,
    protected val originals: mutable.Buffer[Expression] = new ArrayBuffer)
  extends Set[Expression] {

  protected def add(e: Expression): Unit = {
    if (!baseSet.contains(e.canonicalized)) {
      baseSet.add(e.canonicalized)
      originals.append(e)
    }
  }

  override def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized)

  override def +(elem: Expression): ExpressionSet = {
    val newSet = new ExpressionSet(baseSet.clone(), originals.clone())
    newSet.add(elem)
    newSet
  }

  override def -(elem: Expression): ExpressionSet = {
    val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized)
    val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized)
    new ExpressionSet(newBaseSet, newOriginals)
  }

  override def iterator: Iterator[Expression] = originals.iterator

  /**
   * Returns a string containing both the post-[[Canonicalize]] expressions and the original
   * expressions in this set.
   */
  def toDebugString: String =
    s"""
       |baseSet: ${baseSet.mkString(", ")}
       |originals: ${originals.mkString(", ")}
     """.stripMargin
}
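
A short usage sketch of the set semantics described above (the fixtures are assumptions, not part of the diff); note that `+` and `-` return new sets rather than mutating the receiver:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)(exprId = ExprId(1))

val set = ExpressionSet(Seq(Add(a, Literal(1)), Add(Literal(1), a)))
assert(set.size == 1)                     // both inputs share one canonical form
assert(set.contains(Add(Literal(1), a)))  // membership checks are canonical

val grown = set + Add(a, Literal(2))      // returns a fresh set
assert(grown.size == 2 && set.size == 1)  // the original set is unchanged
```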
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -63,17 +63,19 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
   }
 
   /**
-   * A sequence of expressions that describes the data property of the output rows of this
-   * operator. For example, if the output of this operator is column `a`, an example `constraints`
-   * can be `Set(a > 10, a < 20)`.
+   * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
+   * example, if this set contains the expression `a = 2` then that expression is guaranteed to
+   * evaluate to `true` for all rows produced.
    */
-  lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints)
+  lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))
 
   /**
    * This method can be overridden by any child class of QueryPlan to specify a set of constraints
    * based on the given operator's constraint propagation logic. These constraints are then
    * canonicalized and filtered automatically to contain only those attributes that appear in the
-   * [[outputSet]]
+   * [[outputSet]].
+   *
+   * See [[Canonicalize]] for more details.
    */
   protected def validConstraints: Set[Expression] = Set.empty
 
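
For context on how an operator plugs into this contract, a sketch (not part of the diff): an operator overrides `validConstraints`, and `constraints` then canonicalizes the result and filters it to the attributes in `outputSet`. `SimpleFilter` below is a hypothetical operator modeled on Catalyst's real `Filter`.

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}

// Hypothetical operator: rows that survive a filter are guaranteed to
// satisfy its condition, so the condition joins the child's constraints.
case class SimpleFilter(condition: Expression, child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output
  override protected def validConstraints: Set[Expression] =
    child.constraints + condition
}
```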
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.IntegerType

class ExpressionSetSuite extends SparkFunSuite {

  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))

  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  def setTest(size: Int, exprs: Expression*): Unit = {
    test(s"expect $size: ${exprs.mkString(", ")}") {
      val set = ExpressionSet(exprs)
      if (set.size != size) {
        fail(set.toDebugString)
      }
    }
  }

  def setTestIgnore(size: Int, exprs: Expression*): Unit =
    ignore(s"expect $size: ${exprs.mkString(", ")}") {}

  // Commutative
  setTest(1, aUpper + 1, aLower + 1)
  setTest(2, aUpper + 1, aLower + 2)
  setTest(2, aUpper + 1, fakeA + 1)
  setTest(2, aUpper + 1, bUpper + 1)

  setTest(1, aUpper + aLower, aLower + aUpper)
  setTest(1, aUpper + bUpper, bUpper + aUpper)
  setTest(1,
    aUpper + bUpper + 3,
    bUpper + 3 + aUpper,
    bUpper + aUpper + 3,
    Literal(3) + aUpper + bUpper)
  setTest(1,
    aUpper * bUpper * 3,
    bUpper * 3 * aUpper,
    bUpper * aUpper * 3,
    Literal(3) * aUpper * bUpper)
  setTest(1, aUpper === bUpper, bUpper === aUpper)

  setTest(1, aUpper + 1 === bUpper, bUpper === Literal(1) + aUpper)

  // Not commutative
  setTest(2, aUpper - bUpper, bUpper - aUpper)

  // Reversible
  setTest(1, aUpper > bUpper, bUpper < aUpper)
  setTest(1, aUpper >= bUpper, bUpper <= aUpper)

  test("add to / remove from set") {
    val initialSet = ExpressionSet(aUpper + 1 :: Nil)

    assert((initialSet + (aUpper + 1)).size == 1)
    assert((initialSet + (aUpper + 2)).size == 2)
    assert((initialSet - (aUpper + 1)).size == 0)
    assert((initialSet - (aUpper + 2)).size == 1)

Contributor: Also test (initialSet + (aLower + 1)).size and (initialSet - (aLower + 1)).size?

Author: Added

    assert((initialSet + (aLower + 1)).size == 1)
    assert((initialSet - (aLower + 1)).size == 0)
  }
}
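
When one of these `setTest` cases fails, `fail(set.toDebugString)` prints both representations side by side. An illustrative sketch using the suite's fixtures (the exact string rendering of the expressions is an assumption, not verbatim output):

```scala
// Both the canonical forms (names stripped to "none") and the originals
// are visible, which makes duplicate-detection failures easy to diagnose.
val set = ExpressionSet(Seq(aUpper + 1, bUpper + 1))
println(set.toDebugString)
// baseSet: (none#1 + 1), (none#2 + 1)
// originals: (A#1 + 1), (B#2 + 1)
```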
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1980,9 +1980,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     verifyCallCount(
       df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
 
-    // Would be nice if semantic equals for `+` understood commutative
     verifyCallCount(
-      df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
+      df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1)

Author: /cc @nongli



     // Try disabling it via configuration.
     sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")