Skip to content

Commit

Permalink
[SPARK-21228][SQL][BRANCH-2.2] InSet incorrect handling of structs
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This is backport of apache#18455
When data type is struct, InSet now uses TypeUtils.getInterpretedOrdering (similar to EqualTo) to build a TreeSet. In other cases it will use a HashSet as before (which should be faster). Similarly, In.eval uses Ordering.equiv instead of equals.

## How was this patch tested?
New test in SQLQuerySuite.

Author: Bogdan Raducanu <bogdan@databricks.com>

Closes apache#18563 from bogdanrdc/SPARK-21228-BRANCH2.2.
  • Loading branch information
bogdanrdc authored and MatthewRBruce committed Jul 31, 2018
1 parent abfdf1b commit 460e127
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 34 deletions.
Expand Up @@ -17,10 +17,11 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.immutable.TreeSet

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -164,19 +165,22 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
""".stripMargin)
} else {
TypeCheckResult.TypeCheckSuccess
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
case _ =>
if (list.exists(l => l.dataType != value.dataType)) {
TypeCheckResult.TypeCheckFailure("Arguments must be same type")
val mismatchOpt = list.find(l => l.dataType != value.dataType)
if (mismatchOpt.isDefined) {
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
s"${value.dataType} != ${mismatchOpt.get.dataType}")
} else {
TypeCheckResult.TypeCheckSuccess
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}
}

override def children: Seq[Expression] = value +: list
lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType)

override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
Expand All @@ -191,10 +195,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
var hasNull = false
list.foreach { e =>
val v = e.eval(input)
if (v == evaluatedValue) {
return true
} else if (v == null) {
if (v == null) {
hasNull = true
} else if (ordering.equiv(v, evaluatedValue)) {
return true
}
}
if (hasNull) {
Expand Down Expand Up @@ -253,7 +257,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
override def nullable: Boolean = child.nullable || hasNull

protected override def nullSafeEval(value: Any): Any = {
if (hset.contains(value)) {
if (set.contains(value)) {
true
} else if (hasNull) {
null
Expand All @@ -262,27 +266,40 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
}

def getHSet(): Set[Any] = hset
@transient private[this] lazy val set = child.dataType match {
case _: AtomicType => hset
case _: NullType => hset
case _ =>
// for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset
}

def getSet(): Set[Any] = set

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val setName = classOf[Set[Any]].getName
val InSetName = classOf[InSet].getName
val childGen = child.genCode(ctx)
ctx.references += this
val hsetTerm = ctx.freshName("hset")
val hasNullTerm = ctx.freshName("hasNull")
ctx.addMutableState(setName, hsetTerm,
s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();")
ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);")
val setTerm = ctx.freshName("set")
val setNull = if (hasNull) {
s"""
|if (!${ev.value}) {
| ${ev.isNull} = true;
|}
""".stripMargin
} else {
""
}
ctx.addMutableState(setName, setTerm,
s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();")
ev.copy(code = s"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
boolean ${ev.value} = false;
if (!${ev.isNull}) {
${ev.value} = $hsetTerm.contains(${childGen.value});
if (!${ev.value} && $hasNullTerm) {
${ev.isNull} = true;
}
${ev.value} = $setTerm.contains(${childGen.value});
$setNull
}
""")
}
Expand Down
Expand Up @@ -35,7 +35,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
test(s"3VL $name") {
truthTable.foreach {
case (l, r, answer) =>
val expr = op(NonFoldableLiteral(l, BooleanType), NonFoldableLiteral(r, BooleanType))
val expr = op(NonFoldableLiteral.create(l, BooleanType),
NonFoldableLiteral.create(r, BooleanType))
checkEvaluation(expr, answer)
}
}
Expand Down Expand Up @@ -72,7 +73,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(false, true) ::
(null, null) :: Nil
notTrueTable.foreach { case (v, answer) =>
checkEvaluation(Not(NonFoldableLiteral(v, BooleanType)), answer)
checkEvaluation(Not(NonFoldableLiteral.create(v, BooleanType)), answer)
}
checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType)
}
Expand Down Expand Up @@ -120,22 +121,26 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(null, null, null) :: Nil)

test("IN") {
checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq(Literal(1), Literal(2))), null)
checkEvaluation(In(NonFoldableLiteral(null, IntegerType),
Seq(NonFoldableLiteral(null, IntegerType))), null)
checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq.empty), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
Literal(2))), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType),
Seq(NonFoldableLiteral.create(null, IntegerType))), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null)
checkEvaluation(In(Literal(1), Seq.empty), false)
checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
true)
checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
null)
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
checkEvaluation(
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1),
Literal(2)))),
true)

val ns = NonFoldableLiteral(null, StringType)
val ns = NonFoldableLiteral.create(null, StringType)
checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null)
checkEvaluation(In(ns, Seq(ns)), null)
checkEvaluation(In(Literal("a"), Seq(ns)), null)
Expand All @@ -155,7 +160,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
case _ => value
}
}
val input = inputData.map(NonFoldableLiteral(_, t))
val input = inputData.map(NonFoldableLiteral.create(_, t))
val expected = if (inputData(0) == null) {
null
} else if (inputData.slice(1, 10).contains(inputData(0))) {
Expand Down Expand Up @@ -279,7 +284,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
test("BinaryComparison: null test") {
// Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757
val normalInt = Literal(-1)
val nullInt = NonFoldableLiteral(null, IntegerType)
val nullInt = NonFoldableLiteral.create(null, IntegerType)

def nullTest(op: (Expression, Expression) => Expression): Unit = {
checkEvaluation(op(normalInt, nullInt), null)
Expand Down
Expand Up @@ -166,7 +166,7 @@ class OptimizeInSuite extends PlanTest {
val optimizedPlan = OptimizeIn(conf.copy(OPTIMIZER_INSET_CONVERSION_THRESHOLD -> 2))(plan)
optimizedPlan match {
case Filter(cond, _)
if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 =>
if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 =>
// pass
case _ => fail("Unexpected result for OptimizedIn")
}
Expand Down
22 changes: 22 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Expand Up @@ -2624,4 +2624,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)"))
assert(e.message.contains("Invalid number of arguments"))
}

test("SPARK-21228: InSet incorrect handling of structs") {
withTempView("A") {
// reduce this from the default of 10 so the repro query text is not too long
withSQLConf((SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "3")) {
// a relation that has 1 column of struct type with values (1,1), ..., (9, 9)
spark.range(1, 10).selectExpr("named_struct('a', id, 'b', id) as a")
.createOrReplaceTempView("A")
val df = sql(
"""
|SELECT * from
| (SELECT MIN(a) as minA FROM A) AA -- this Aggregate will return UnsafeRows
| -- the IN will become InSet with a Set of GenericInternalRows
| -- a GenericInternalRow is never equal to an UnsafeRow so the query would
| -- returns 0 results, which is incorrect
| WHERE minA IN (NAMED_STRUCT('a', 1L, 'b', 1L), NAMED_STRUCT('a', 2L, 'b', 2L),
| NAMED_STRUCT('a', 3L, 'b', 3L))
""".stripMargin)
checkAnswer(df, Row(Row(1, 1)))
}
}
}
}

0 comments on commit 460e127

Please sign in to comment.