Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][Spark-SQL] Optimize the Constant Folding for Expression #482

Closed
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,37 +114,37 @@ package object dsl {
def attr = analysis.UnresolvedAttribute(s)

/** Creates a new AttributeReference of type boolean */
def boolean = AttributeReference(s, BooleanType, nullable = false)()
def boolean = AttributeReference(s, BooleanType, nullable = true)()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the rationale for changing all of these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think "def boolean" / "def string" should return an Attribute with nullability = true, that does not harm for the correctness, otherwise, it may bring a wrong hint for the optimization of rule NullPropagation.

For example:

val row = new GenericRow(Array[Any] ("a", null))
val c1 = 'a.string.at(0) 
val c2 = 'b.string.at(1) // nullable should be true, otherwise does't reflect the real situation.
assert(evaluate(IsNull(c2), row) == true)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point.


/** Creates a new AttributeReference of type byte */
def byte = AttributeReference(s, ByteType, nullable = false)()
def byte = AttributeReference(s, ByteType, nullable = true)()

/** Creates a new AttributeReference of type short */
def short = AttributeReference(s, ShortType, nullable = false)()
def short = AttributeReference(s, ShortType, nullable = true)()

/** Creates a new AttributeReference of type int */
def int = AttributeReference(s, IntegerType, nullable = false)()
def int = AttributeReference(s, IntegerType, nullable = true)()

/** Creates a new AttributeReference of type long */
def long = AttributeReference(s, LongType, nullable = false)()
def long = AttributeReference(s, LongType, nullable = true)()

/** Creates a new AttributeReference of type float */
def float = AttributeReference(s, FloatType, nullable = false)()
def float = AttributeReference(s, FloatType, nullable = true)()

/** Creates a new AttributeReference of type double */
def double = AttributeReference(s, DoubleType, nullable = false)()
def double = AttributeReference(s, DoubleType, nullable = true)()

/** Creates a new AttributeReference of type string */
def string = AttributeReference(s, StringType, nullable = false)()
def string = AttributeReference(s, StringType, nullable = true)()

/** Creates a new AttributeReference of type decimal */
def decimal = AttributeReference(s, DecimalType, nullable = false)()
def decimal = AttributeReference(s, DecimalType, nullable = true)()

/** Creates a new AttributeReference of type timestamp */
def timestamp = AttributeReference(s, TimestampType, nullable = false)()
def timestamp = AttributeReference(s, TimestampType, nullable = true)()

/** Creates a new AttributeReference of type binary */
def binary = AttributeReference(s, BinaryType, nullable = false)()
def binary = AttributeReference(s, BinaryType, nullable = true)()
}

implicit class DslAttribute(a: AttributeReference) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ abstract class Expression extends TreeNode[Expression] {
* - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its
* child is foldable.
*/
// TODO: Supporting more foldable expressions. For example, deterministic Hive UDFs.
def foldable: Boolean = false
def nullable: Boolean
def references: Set[Attribute]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees

abstract sealed class SortDirection
case object Ascending extends SortDirection
Expand All @@ -27,7 +28,10 @@ case object Descending extends SortDirection
* An expression that can be used to sort a tuple. This class extends expression primarily so that
* transformations over expression will descend into its child.
*/
case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {
case class SortOrder(child: Expression, direction: SortDirection) extends Expression
with trees.UnaryNode[Expression] {

override def references = child.references
override def dataType = child.dataType
override def nullable = child.nullable

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
val children = child :: ordinal :: Nil
/** `Null` is returned for invalid ordinals. */
override def nullable = true
override def foldable = child.foldable && ordinal.foldable
override def references = children.flatMap(_.references).toSet
def dataType = child.dataType match {
case ArrayType(dt) => dt
Expand All @@ -40,23 +41,27 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
override def toString = s"$child[$ordinal]"

override def eval(input: Row): Any = {
if (child.dataType.isInstanceOf[ArrayType]) {
val baseValue = child.eval(input).asInstanceOf[Seq[_]]
val o = ordinal.eval(input).asInstanceOf[Int]
if (baseValue == null) {
null
} else if (o >= baseValue.size || o < 0) {
null
} else {
baseValue(o)
}
val value = child.eval(input)
if(value == null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Space after if, below too.

null
} else {
val baseValue = child.eval(input).asInstanceOf[Map[Any, _]]
val key = ordinal.eval(input)
if (baseValue == null) {
if(key == null) {
null
} else {
baseValue.get(key).orNull
if (child.dataType.isInstanceOf[ArrayType]) {
val baseValue = value.asInstanceOf[Seq[_]]
val o = key.asInstanceOf[Int]
if (o >= baseValue.size || o < 0) {
null
} else {
baseValue(o)
}
} else {
val baseValue = value.asInstanceOf[Map[Any, _]]
val key = ordinal.eval(input)
baseValue.get(key).orNull
}
}
}
}
Expand All @@ -69,7 +74,8 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio
type EvaluatedType = Any

def dataType = field.dataType
def nullable = field.nullable
override def nullable = field.nullable
override def foldable = child.foldable

protected def structType = child.dataType match {
case s: StructType => s
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ abstract class BinaryPredicate extends BinaryExpression with Predicate {
def nullable = left.nullable || right.nullable
}

case class Not(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
def references = child.references
case class Not(child: Expression) extends UnaryExpression with Predicate {
override def foldable = child.foldable
def nullable = child.nullable
override def toString = s"NOT $child"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.types._
object Optimizer extends RuleExecutor[LogicalPlan] {
val batches =
Batch("ConstantFolding", Once,
NullPropagation,
ConstantFolding,
BooleanSimplification,
SimplifyFilters,
Expand Down Expand Up @@ -85,6 +86,71 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
}

/**
* Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with
* equivalent [[catalyst.expressions.Literal Literal]] values. This rule is more specific with
* Null value propagation from bottom to top of the expression tree.
*/
object NullPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want Up? I believe this means we are going to call evaluate on each foldable node working up instead of just calling it once at the top.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expression.foldable is ok by traveling from top to bottom, while null propagation is opposite. I've put them into different rule objects (ConstantFolding & NullPropagation).

case e @ Count(Literal(null, _)) => Literal(0, e.dataType)
case e @ Sum(Literal(c, _)) if(c == 0) => Literal(0, e.dataType)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not like the previous implementation, I didn't add the null value as Child expression for both Sum and Average in the rule, cause Hive will throw semantic exception.

case e @ Average(Literal(c, _)) if(c == 0) => Literal(0.0, e.dataType)
case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType)
case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType)
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
case e @ Coalesce(children) => {
val newChildren = children.filter(c => c match {
case Literal(null, _) => false
case _ => true
})
if(newChildren.length == 0) {
Literal(null, e.dataType)
} else if(newChildren.length == 1) {
newChildren(0)
} else {
Coalesce(newChildren)
}
}
case e @ If(Literal(v, _), trueValue, falseValue) => if(v == true) trueValue else falseValue
case e @ In(Literal(v, _), list) if(list.exists(c => c match {
case Literal(candidate, _) if(candidate == v) => true
case _ => false
})) => Literal(true, BooleanType)
case e: UnaryMinus => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
case e: Cast => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
case e: Not => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
// Put exceptional cases above if any
case e: BinaryArithmetic => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case _ => e
}
case e: BinaryComparison => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case _ => e
}
case e: StringRegexExpression => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case _ => e
}
}
}
}
/**
* Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with
* equivalent [[catalyst.expressions.Literal Literal]] values.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,7 @@ class ExpressionEvaluationSuite extends FunSuite {
truthTable.foreach {
case (l,r,answer) =>
val expr = op(Literal(l, BooleanType), Literal(r, BooleanType))
val result = expr.eval(null)
if (result != answer)
fail(s"$expr should not evaluate to $result, expected: $answer")
checkEvaluation(expr, answer)
}
}
}
Expand All @@ -131,6 +129,7 @@ class ExpressionEvaluationSuite extends FunSuite {

test("LIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType).like("a"), null)
checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null)
checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null)
checkEvaluation("abdef" like "abdef", true)
checkEvaluation("a_%b" like "a\\__b", true)
Expand Down Expand Up @@ -159,9 +158,14 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%")))
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%")))
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%")))

checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%")))
}

test("RLIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType) rlike "abdef", null)
checkEvaluation("abdef" rlike Literal(null, StringType), null)
checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null)
checkEvaluation("abdef" rlike "abdef", true)
checkEvaluation("abbbbc" rlike "a.*c", true)

Expand Down Expand Up @@ -257,6 +261,8 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(("abcdef" cast DecimalType).nullable === true)
assert(("abcdef" cast DoubleType).nullable === true)
assert(("abcdef" cast FloatType).nullable === true)

checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null)
}

test("timestamp") {
Expand Down Expand Up @@ -287,5 +293,108 @@ class ExpressionEvaluationSuite extends FunSuite {
// A test for higher precision than millis
checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001)
}

test("null checking") {
val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
val c3 = 'a.boolean.at(2)
val c4 = 'a.boolean.at(3)

checkEvaluation(IsNull(c1), false, row)
checkEvaluation(IsNotNull(c1), true, row)

checkEvaluation(IsNull(c2), true, row)
checkEvaluation(IsNotNull(c2), false, row)

checkEvaluation(IsNull(Literal(1, ShortType)), false)
checkEvaluation(IsNotNull(Literal(1, ShortType)), true)

checkEvaluation(IsNull(Literal(null, ShortType)), true)
checkEvaluation(IsNotNull(Literal(null, ShortType)), false)

checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row)
checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row)

checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row)
checkEvaluation(If(c3, c1, c2), "^Ba*n", row)
checkEvaluation(If(c4, c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row)
checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal(false, BooleanType),
Literal("a", StringType), Literal("b", StringType)), "b", row)

checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row)
checkEvaluation(In(Literal("^Ba*n", StringType),
Literal("^Ba*n", StringType) :: Nil), true, row)
checkEvaluation(In(Literal("^Ba*n", StringType),
Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
}

test("complex type") {
val row = new GenericRow(Array[Any](
"^Ba*n", // 0
null.asInstanceOf[String], // 1
new GenericRow(Array[Any]("aa", "bb")), // 2
Map("aa"->"bb"), // 3
Seq("aa", "bb") // 4
))

val typeS = StructType(
StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
)
val typeMap = MapType(StringType, StringType)
val typeArray = ArrayType(StringType)

checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
Literal("aa")), "bb", row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row)
checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
Literal(null, StringType)), null, row)

checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
Literal(1)), "bb", row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row)
checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
Literal(null, IntegerType)), null, row)

checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row)
checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
}

test("arithmetic") {
val row = new GenericRow(Array[Any](1, 2, 3, null))
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
val c4 = 'a.int.at(3)

checkEvaluation(UnaryMinus(c1), -1, row)
checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100)

checkEvaluation(Add(c1, c4), null, row)
checkEvaluation(Add(c1, c2), 3, row)
checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row)
checkEvaluation(Add(Literal(null, IntegerType), c2), null, row)
checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
}

test("BinaryComparison") {
val row = new GenericRow(Array[Any](1, 2, 3, null))
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
val c4 = 'a.int.at(3)

checkEvaluation(LessThan(c1, c4), null, row)
checkEvaluation(LessThan(c1, c2), true, row)
checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row)
checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row)
checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
}
}

Loading