Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-21228][SQL] InSet incorrect handling of structs #18455

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.immutable.TreeSet

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -175,20 +176,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
""".stripMargin)
} else {
TypeCheckResult.TypeCheckSuccess
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}
case _ =>
if (list.exists(l => l.dataType != value.dataType)) {
TypeCheckResult.TypeCheckFailure("Arguments must be same type")
val mismatchOpt = list.find(l => l.dataType != value.dataType)
if (mismatchOpt.isDefined) {
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
s"${value.dataType} != ${mismatchOpt.get.dataType}")
} else {
TypeCheckResult.TypeCheckSuccess
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}
}

// The tested value followed by every candidate expression of the IN list.
override def children: Seq[Expression] = value +: list
// True when every element of the IN list is a Literal.
lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
// Interpreted ordering for the value's data type; its `equiv` is used to compare an
// evaluated list entry with the evaluated value.
private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType)

// Nullable as soon as the value or any list entry is nullable.
override def nullable: Boolean = children.exists(_.nullable)
// Foldable only when the value and every list entry are foldable.
override def foldable: Boolean = children.forall(_.foldable)
Expand All @@ -203,10 +207,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
var hasNull = false
list.foreach { e =>
val v = e.eval(input)
if (v == evaluatedValue) {
return true
} else if (v == null) {
if (v == null) {
hasNull = true
} else if (ordering.equiv(v, evaluatedValue)) {
return true
}
}
if (hasNull) {
Expand Down Expand Up @@ -265,7 +269,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
override def nullable: Boolean = child.nullable || hasNull

protected override def nullSafeEval(value: Any): Any = {
if (hset.contains(value)) {
if (set.contains(value)) {
true
} else if (hasNull) {
null
Expand All @@ -274,27 +278,40 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
}

def getHSet(): Set[Any] = hset
@transient private[this] lazy val set = child.dataType match {
case _: AtomicType => hset
case _: NullType => hset
case _ =>
// for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
TreeSet.empty(TypeUtils.getInterpretedOrdering (child.dataType) ) ++ hset
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the extra spaces around the parentheses, i.e. write `.getInterpretedOrdering(child.dataType))`.

}

// Returns the set used for the membership check. Generated code obtains it by calling
// `getSet()` on this expression through the `references` array, so the `()` must stay.
def getSet(): Set[Any] = set

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val setName = classOf[Set[Any]].getName
val InSetName = classOf[InSet].getName
val childGen = child.genCode(ctx)
ctx.references += this
val hsetTerm = ctx.freshName("hset")
val hasNullTerm = ctx.freshName("hasNull")
ctx.addMutableState(setName, hsetTerm,
s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();")
ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);")
val setTerm = ctx.freshName("set")
val setNull = if (hasNull) {
s"""
|if (!${ev.value}) {
| ${ev.isNull} = true;
|}
""".stripMargin
} else {
""
}
ctx.addMutableState(setName, setTerm,
s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();")
ev.copy(code = s"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
boolean ${ev.value} = false;
if (!${ev.isNull}) {
${ev.value} = $hsetTerm.contains(${childGen.value});
if (!${ev.value} && $hasNullTerm) {
${ev.isNull} = true;
}
${ev.value} = $setTerm.contains(${childGen.value});
$setNull
}
""")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
test(s"3VL $name") {
truthTable.foreach {
case (l, r, answer) =>
val expr = op(NonFoldableLiteral(l, BooleanType), NonFoldableLiteral(r, BooleanType))
val expr = op(NonFoldableLiteral.create(l, BooleanType),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these changes needed, i.e. adding `.create`?

Copy link
Contributor Author

@bogdanrdc bogdanrdc Jul 6, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that NonFoldableLiteral(java.math.BigDecimal, DecimalType) keeps the Java value, while NonFoldableLiteral.create(value, type) calls Literal.create, which converts java.math.BigDecimal to Decimal (the Catalyst type).
This is a somewhat odd part of the code, and Literal has the same problem: Literal(value, type) is not the same as Literal.create(value, type). So one should always use .create and never the two-parameter constructor.
Otherwise Java values end up in the execution, which failed in my case because there is no ordering for non-Catalyst types.

NonFoldableLiteral.create(r, BooleanType))
checkEvaluation(expr, answer)
}
}
Expand Down Expand Up @@ -72,7 +73,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(false, true) ::
(null, null) :: Nil
notTrueTable.foreach { case (v, answer) =>
checkEvaluation(Not(NonFoldableLiteral(v, BooleanType)), answer)
checkEvaluation(Not(NonFoldableLiteral.create(v, BooleanType)), answer)
}
checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType)
}
Expand Down Expand Up @@ -120,22 +121,26 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(null, null, null) :: Nil)

test("IN") {
checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq(Literal(1), Literal(2))), null)
checkEvaluation(In(NonFoldableLiteral(null, IntegerType),
Seq(NonFoldableLiteral(null, IntegerType))), null)
checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq.empty), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
Literal(2))), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType),
Seq(NonFoldableLiteral.create(null, IntegerType))), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null)
checkEvaluation(In(Literal(1), Seq.empty), false)
checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
true)
checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
null)
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
checkEvaluation(
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1),
Literal(2)))),
true)

val ns = NonFoldableLiteral(null, StringType)
val ns = NonFoldableLiteral.create(null, StringType)
checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null)
checkEvaluation(In(ns, Seq(ns)), null)
checkEvaluation(In(Literal("a"), Seq(ns)), null)
Expand All @@ -155,7 +160,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
case _ => value
}
}
val input = inputData.map(NonFoldableLiteral(_, t))
val input = inputData.map(NonFoldableLiteral.create(_, t))
val expected = if (inputData(0) == null) {
null
} else if (inputData.slice(1, 10).contains(inputData(0))) {
Expand Down Expand Up @@ -279,7 +284,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
test("BinaryComparison: null test") {
// Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757
val normalInt = Literal(-1)
val nullInt = NonFoldableLiteral(null, IntegerType)
val nullInt = NonFoldableLiteral.create(null, IntegerType)

def nullTest(op: (Expression, Expression) => Expression): Unit = {
checkEvaluation(op(normalInt, nullInt), null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class OptimizeInSuite extends PlanTest {
val optimizedPlan = OptimizeIn(conf.copy(OPTIMIZER_INSET_CONVERSION_THRESHOLD -> 2))(plan)
optimizedPlan match {
case Filter(cond, _)
if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 =>
if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 =>
// pass
case _ => fail("Unexpected result for OptimizedIn")
}
Expand Down
22 changes: 22 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2616,4 +2616,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)"))
assert(e.message.contains("Invalid number of arguments"))
}

// Regression test for SPARK-21228: an IN predicate over struct values must still match
// after the optimizer rewrites it to InSet.
//
// The subquery computes MIN(a) over the structs (1,1), ..., (9,9), i.e. (1,1). That value
// is part of the IN list, so exactly one row is expected.
test("SPARK-21228: InSet incorrect handling of structs") {
  withTempView("A") {
    // Lower the conversion threshold from its default of 10 so the repro query text stays
    // short. It is kept strictly below the IN-list size (3) so that the IN is converted to
    // InSet whether the optimizer compares the list size to the threshold strictly or not;
    // OptimizeInSuite likewise uses a threshold of 2 for a 3-element list.
    withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "2") {
      // a relation that has 1 column of struct type with values (1,1), ..., (9, 9)
      spark.range(1, 10).selectExpr("named_struct('a', id, 'b', id) as a")
        .createOrReplaceTempView("A")
      val df = sql(
        """
          |SELECT * from
          | (SELECT MIN(a) as minA FROM A) AA -- this Aggregate will return UnsafeRows
          | -- the IN will become InSet with a Set of GenericInternalRows
          | -- a GenericInternalRow is never equal to an UnsafeRow so the query would
          | -- returns 0 results, which is incorrect
          | WHERE minA IN (NAMED_STRUCT('a', 1L, 'b', 1L), NAMED_STRUCT('a', 2L, 'b', 2L),
          | NAMED_STRUCT('a', 3L, 'b', 3L))
        """.stripMargin)
      // The single expected row holds the minimum struct (1, 1).
      checkAnswer(df, Row(Row(1, 1)))
    }
  }
}
}