Skip to content

Commit

Permalink
[SPARK-13995][SQL] Extract correct IsNotNull constraints for Expression
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

JIRA: https://issues.apache.org/jira/browse/SPARK-13995

We infer relative `IsNotNull` constraints from logical plan's expressions in `constructIsNotNullConstraints` now. However, we don't consider the case of (nested) `Cast`.

For example:

    val tr = LocalRelation('a.int, 'b.long)
    val plan = tr.where('a.attr === 'b.attr).analyze

Then, the plan's constraints will have `IsNotNull(Cast(resolveColumn(tr, "a"), LongType))`, instead of `IsNotNull(resolveColumn(tr, "a"))`. This PR fixes it.

Besides, as `IsNotNull` constraints are most useful for `Attribute`, we should do recursing through any `Expression` that is null intolerant and construct `IsNotNull` constraints for all `Attribute`s under these Expressions.

For example, consider the following constraints:

    val df = Seq((1,2,3)).toDF("a", "b", "c")
    df.where("a + b = c").queryExecution.analyzed.constraints

The inferred isnotnull constraints should be isnotnull(a), isnotnull(b), isnotnull(c), instead of isnotnull(a + c) and isnotnull(c).

## How was this patch tested?

Test is added into `ConstraintPropagationSuite`.

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #11809 from viirya/constraint-cast.
  • Loading branch information
viirya authored and marmbrus committed Apr 1, 2016
1 parent 381358f commit df68beb
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ object Cast {
}

/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant {

override def toString: String = s"cast($child as ${dataType.simpleString})"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval


case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class UnaryMinus(child: Expression) extends UnaryExpression
with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

Expand Down Expand Up @@ -58,7 +59,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
override def sql: String = s"(-${child.sql})"
}

case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class UnaryPositive(child: Expression)
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def prettyName: String = "positive"

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
Expand All @@ -79,7 +81,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the absolute value of the numeric value",
extended = "> SELECT _FUNC_('-1');\n1")
case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class Abs(child: Expression)
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

Expand Down Expand Up @@ -123,7 +126,7 @@ private[sql] object BinaryArithmetic {
def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
}

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

Expand Down Expand Up @@ -152,7 +155,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
}
}

case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
case class Subtract(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

Expand Down Expand Up @@ -181,7 +185,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
}
}

case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
case class Multiply(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = NumericType

Expand All @@ -193,7 +198,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}

case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
case class Divide(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = NumericType

Expand Down Expand Up @@ -269,7 +275,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
}

case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
case class Remainder(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = NumericType

Expand Down Expand Up @@ -457,7 +464,7 @@ case class MinOf(left: Expression, right: Expression)
override def symbol: String = "min"
}

case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {

override def toString: String = s"pmod($left, $right)"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ trait NamedExpression extends Expression {
}
}

abstract class Attribute extends LeafExpression with NamedExpression {
abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant {

override def references: AttributeSet = AttributeSet(this)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,11 @@ package object expressions {
StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
}
}

/**
* When an expression inherits this, meaning the expression is null intolerant (i.e. any null
* input will result in null output). We will use this information during constructing IsNotNull
* constraints.
*/
trait NullIntolerant
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ trait PredicateHelper {


case class Not(child: Expression)
extends UnaryExpression with Predicate with ImplicitCastInputTypes {
extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant {

override def toString: String = s"NOT $child"

Expand Down Expand Up @@ -402,7 +402,8 @@ private[sql] object Equality {
}


case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
case class EqualTo(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = AnyDataType

Expand Down Expand Up @@ -467,7 +468,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
}


case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
case class LessThan(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

Expand All @@ -479,7 +481,8 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
}


case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
case class LessThanOrEqual(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

Expand All @@ -491,7 +494,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
}


case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
case class GreaterThan(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

Expand All @@ -503,7 +507,8 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
}


case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
case class GreaterThanOrEqual(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* returns a constraint of the form `isNotNull(a)`
*/
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
var isNotNullConstraints = Set.empty[Expression]

// First, we propagate constraints if the condition consists of equality and ranges. For all
// other cases, we return an empty set of constraints
constraints.foreach {
case EqualTo(l, r) =>
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
case GreaterThan(l, r) =>
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
case GreaterThanOrEqual(l, r) =>
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
case LessThan(l, r) =>
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
case LessThanOrEqual(l, r) =>
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
case Not(EqualTo(l, r)) =>
isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
case _ => // No inference
}
// First, we propagate constraints from the null intolerant expressions.
var isNotNullConstraints: Set[Expression] =
constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))

// Second, we infer additional constraints from non-nullable attributes that are part of the
// operator's output
Expand All @@ -72,6 +56,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
isNotNullConstraints -- constraints
}

/**
* Recursively explores the expressions which are null intolerant and returns all attributes
* in these expressions.
*/
private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
case a: Attribute => Seq(a)
case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
expr.children.flatMap(scanNullIntolerantExpr)
case _ => Seq.empty[Attribute]
}

/**
* Infers an additional set of constraints from a given set of equality constraints.
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}

class ConstraintPropagationSuite extends SparkFunSuite {

Expand Down Expand Up @@ -219,6 +219,89 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "b")))))
}

test("infer constraints on cast") {
val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
verifyConstraints(
tr.where('a.attr === 'b.attr &&
'c.attr + 100 > 'd.attr &&
IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints,
ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"),
Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "d")),
IsNotNull(resolveColumn(tr, "e")),
IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType)))))
}

test("infer isnotnull constraints from compound expressions") {
val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
verifyConstraints(
tr.where('a.attr + 'b.attr === 'c.attr &&
IsNotNull(
Cast(
Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints,
ExpressionSet(Seq(
Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") ===
Cast(resolveColumn(tr, "c"), LongType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "e")),
IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType)))))

verifyConstraints(
tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints,
ExpressionSet(Seq(
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
Cast(resolveColumn(tr, "c"), LongType),
Cast(resolveColumn(tr, "d"), DoubleType) /
Cast(Cast(10, LongType), DoubleType) ===
Cast(resolveColumn(tr, "e"), DoubleType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "d")),
IsNotNull(resolveColumn(tr, "e")))))

verifyConstraints(
tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints,
ExpressionSet(Seq(
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
Cast(resolveColumn(tr, "c"), LongType),
Cast(resolveColumn(tr, "d"), DoubleType) /
Cast(Cast(10, LongType), DoubleType) <
Cast(resolveColumn(tr, "e"), DoubleType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "d")),
IsNotNull(resolveColumn(tr, "e")))))

verifyConstraints(
tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints,
ExpressionSet(Seq(
(Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) -
(Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) >
Cast(resolveColumn(tr, "e") * 1000, LongType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "d")),
IsNotNull(resolveColumn(tr, "e")))))

// The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null.
verifyConstraints(
tr.where('a.attr === 'c.attr &&
IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints,
ExpressionSet(Seq(
resolveColumn(tr, "a") === resolveColumn(tr, "c"),
IsNotNull(IsNotNull(resolveColumn(tr, "b"))),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "c")))))
}

test("infer IsNotNull constraints from non-nullable attributes") {
val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(),
AttributeReference("c", StringType, nullable = false)())
Expand Down

0 comments on commit df68beb

Please sign in to comment.