Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-20121][SQL] simplify NullPropagation with NullIntolerant #17450

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ case class Abs(child: Expression)
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
}

abstract class BinaryArithmetic extends BinaryOperator {
abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {

override def dataType: DataType = left.dataType

Expand Down Expand Up @@ -146,7 +146,7 @@ object BinaryArithmetic {
> SELECT 1 _FUNC_ 2;
3
""")
case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

Expand Down Expand Up @@ -182,8 +182,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit
> SELECT 2 _FUNC_ 1;
1
""")
case class Subtract(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

Expand Down Expand Up @@ -219,8 +218,7 @@ case class Subtract(left: Expression, right: Expression)
> SELECT 2 _FUNC_ 3;
6
""")
case class Multiply(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = NumericType

Expand All @@ -243,8 +241,7 @@ case class Multiply(left: Expression, right: Expression)
1.0
""")
// scalastyle:on line.size.limit
case class Divide(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)

Expand Down Expand Up @@ -324,8 +321,7 @@ case class Divide(left: Expression, right: Expression)
> SELECT 2 _FUNC_ 1.8;
0.2
""")
case class Remainder(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = NumericType

Expand Down Expand Up @@ -412,7 +408,7 @@ case class Remainder(left: Expression, right: Expression)
> SELECT _FUNC_(-10, 3);
2
""")
case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {

override def toString: String = s"pmod($left, $right)"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ trait ExtractValue extends Expression
* For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
*/
case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
extends UnaryExpression with ExtractValue {
extends UnaryExpression with ExtractValue with NullIntolerant {

lazy val childSchema = child.dataType.asInstanceOf[StructType]

Expand Down Expand Up @@ -152,7 +152,7 @@ case class GetArrayStructFields(
field: StructField,
ordinal: Int,
numFields: Int,
containsNull: Boolean) extends UnaryExpression with ExtractValue {
containsNull: Boolean) extends UnaryExpression with ExtractValue with NullIntolerant {

override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def toString: String = s"$child.${field.name}"
Expand Down Expand Up @@ -213,7 +213,7 @@ case class GetArrayStructFields(
* We need to do type checking here as the `ordinal` expression may be unresolved.
*/
case class GetArrayItem(child: Expression, ordinal: Expression)
extends BinaryExpression with ExpectsInputTypes with ExtractValue {
extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant {

// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
Expand Down Expand Up @@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
* We need to do type checking here as the `key` expression may be unresolved.
*/
case class GetMapValue(child: Expression, key: Expression)
extends BinaryExpression with ImplicitCastInputTypes with ExtractValue {
extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant {

private def keyType = child.dataType.asInstanceOf[MapType].keyType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,5 @@ package object expressions {
* input will result in null output). We will use this information during constructing IsNotNull
* constraints.
*/
trait NullIntolerant
trait NullIntolerant extends Expression
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String


trait StringRegexExpression extends ImplicitCastInputTypes {
self: BinaryExpression =>
abstract class StringRegexExpression extends BinaryExpression
with ImplicitCastInputTypes with NullIntolerant {

def escape(v: String): String
def matches(regex: Pattern, str: String): Boolean
Expand Down Expand Up @@ -69,8 +69,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes {
*/
@ExpressionDescription(
usage = "str _FUNC_ pattern - Returns true if `str` matches `pattern`, or false otherwise.")
case class Like(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {
case class Like(left: Expression, right: Expression) extends StringRegexExpression {

override def escape(v: String): String = StringUtils.escapeLikeRegex(v)

Expand Down Expand Up @@ -122,8 +121,7 @@ case class Like(left: Expression, right: Expression)

@ExpressionDescription(
usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.")
case class RLike(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {
case class RLike(left: Expression, right: Expression) extends StringRegexExpression {

override def escape(v: String): String = v
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,8 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx
}

/** A base trait for functions that compare two strings, returning a boolean. */
trait StringPredicate extends Predicate with ImplicitCastInputTypes {
self: BinaryExpression =>
abstract class StringPredicate extends BinaryExpression
with Predicate with ImplicitCastInputTypes {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing with NullIntolerant here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is just to simplify the existing rule NullPropagation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See StringRegexExpression above: similarly, in order to simplify NullPropagation, we need to add NullIntolerant here so that it can propagate null values.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I finally got your point. StringPredicate is used for inferring the null constants in the rule NullPropagation. Thus, we should mark it as NullIntolerant.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. :-)


def compare(l: UTF8String, r: UTF8String): Boolean

Expand All @@ -313,8 +313,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes {
/**
* A function that returns true if the string `left` contains the string `right`.
*/
case class Contains(left: Expression, right: Expression)
extends BinaryExpression with StringPredicate {
case class Contains(left: Expression, right: Expression) extends StringPredicate {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)")
Expand All @@ -324,8 +323,7 @@ case class Contains(left: Expression, right: Expression)
/**
* A function that returns true if the string `left` starts with the string `right`.
*/
case class StartsWith(left: Expression, right: Expression)
extends BinaryExpression with StringPredicate {
case class StartsWith(left: Expression, right: Expression) extends StringPredicate {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)")
Expand All @@ -335,8 +333,7 @@ case class StartsWith(left: Expression, right: Expression)
/**
* A function that returns true if the string `left` ends with the string `right`.
*/
case class EndsWith(left: Expression, right: Expression)
extends BinaryExpression with StringPredicate {
case class EndsWith(left: Expression, right: Expression) extends StringPredicate {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)")
Expand Down Expand Up @@ -1122,7 +1119,7 @@ case class StringSpace(child: Expression)
""")
// scalastyle:on line.size.limit
case class Substring(str: Expression, pos: Expression, len: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the function SUBSTRING null-intolerant? What is the return value if str is a null value?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The result can be null; if any argument is null, the result is the null value.

Ref: https://www.ibm.com/support/knowledgecenter/en/SSEPEK_10.0.0/sqlref/src/tpc/db2z_bif_substr.html

Copy link
Contributor

@nsyca nsyca Mar 29, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be confused by the terminology: NullIntolerant expression versus "null-intolerant predicate". But if SUBSTRING is marked as a null-intolerant expression, why do we not mark the class of string functions such as STARTSWITH, etc. the same way? Am I missing anything here?

Copy link
Member

@gatorsmile gatorsmile Mar 29, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we should mark the other expressions as NullIntolerant, if possible, and also update the documentation.

Copy link
Member

@gatorsmile gatorsmile Mar 29, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nsyca If you have the bandwidth, could you please review all the expressions and see whether they can be marked as NullIntolerant?

You can check the impl of these expressions and compare them with the corresponding ones in the other RDBMS. Thanks!

Below is a ref PR you can use: https://github.com/apache/spark/pull/15850/files. You can continue my work if you want.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will certainly take a look. On second thought, since "most" of the SQL functions are null-intolerant, isn't it easier to mark only the functions that are null-tolerant, such as ISNOTNULL? I am just pitching an idea here, not indicating we should abandon this PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the beginning, when we introduced NullIntolerant, @marmbrus said we should do it more carefully. We prefer the white-list solution to avoid hidden bugs: mark an expression as NullIntolerant if and only if we are sure that it is null-intolerant. Thus, when doing so, we should also add the corresponding test cases and documentation.


def this(str: Expression, pos: Expression) = {
this(str, pos, Literal(Integer.MAX_VALUE))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,35 +347,30 @@ object LikeSimplification extends Rule[LogicalPlan] {
* Null value propagation from bottom to top of the expression tree.
*/
case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] {
private def nonNullLiteral(e: Expression): Boolean = e match {
case Literal(null, _) => false
case _ => true
private def isNullLiteral(e: Expression): Boolean = e match {
case Literal(null, _) => true
case _ => false
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) =>
Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) =>
Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ GetArrayItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) =>
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
// This rule should be only triggered when isDistinct field is false.
ae.copy(aggregateFunction = Count(Literal(1)))

case IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)

case EqualNullSafe(Literal(null, _), r) => IsNull(r)
case EqualNullSafe(l, Literal(null, _)) => IsNull(l)

// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
val newChildren = children.filter(nonNullLiteral)
val newChildren = children.filterNot(isNullLiteral)
if (newChildren.isEmpty) {
Literal.create(null, e.dataType)
} else if (newChildren.length == 1) {
Expand All @@ -384,33 +379,13 @@ case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] {
Coalesce(newChildren)
}

case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType)
case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType)

// Put exceptional cases above if any
case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType)

case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType)

case e: StringRegexExpression => e.children match {
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
case _ => e
}

case e: StringPredicate => e.children match {
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
case _ => e
}

// If the value expression is NULL then transform the In expression to
// Literal(null)
case In(Literal(null, _), list) => Literal.create(null, BooleanType)

// Put exceptional cases above if any
Copy link
Member

@gatorsmile gatorsmile Mar 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Attribute is also NullIntolerant. Maybe add a comment?
Non-leaf NullIntolerant expressions will return null if at least one of their children is a null literal.

case e: NullIntolerant if e.children.exists(isNullLiteral) =>
Literal.create(null, e.dataType)
}
}
}
Expand Down