Skip to content

Commit

Permalink
fix by using TypeUtils.getInterpretedOrdering
Browse files Browse the repository at this point in the history
  • Loading branch information
bogdanrdc committed Jun 28, 2017
1 parent 1057abe commit a1bceda
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.immutable.TreeSet

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -179,8 +180,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
}
case _ =>
if (list.exists(l => l.dataType != value.dataType)) {
TypeCheckResult.TypeCheckFailure("Arguments must be same type")
val mismatchOpt = list.find(l => l.dataType != value.dataType)
if (mismatchOpt.isDefined) {
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
s"${value.dataType} != ${mismatchOpt.get.dataType}")
} else {
TypeCheckResult.TypeCheckSuccess
}
Expand All @@ -189,6 +192,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {

// The tested value comes first, followed by every candidate in the IN list.
override def children: Seq[Expression] = value +: list
// True when every element of the IN list is a Literal.
lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
// Interpreted ordering for the value's data type; its `equiv` is used to compare each
// evaluated list element against the evaluated value instead of relying on `==`.
private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType)

// Nullable as soon as any child (the value or a list element) is nullable.
override def nullable: Boolean = children.exists(_.nullable)
// Foldable only when the value and every list element are foldable.
override def foldable: Boolean = children.forall(_.foldable)
Expand All @@ -203,10 +207,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
var hasNull = false
list.foreach { e =>
val v = e.eval(input)
if (v == evaluatedValue) {
return true
} else if (v == null) {
if (v == null) {
hasNull = true
} else if (ordering.equiv(v, evaluatedValue)) {
return true
}
}
if (hasNull) {
Expand Down Expand Up @@ -265,7 +269,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
// Nullable when the child is nullable, or when `hasNull` is set: in that case a value that is
// not found in the set evaluates to null rather than false (see nullSafeEval).
override def nullable: Boolean = child.nullable || hasNull

protected override def nullSafeEval(value: Any): Any = {
if (hset.contains(value)) {
if (set.contains(value)) {
true
} else if (hasNull) {
null
Expand All @@ -274,24 +278,32 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
}

def getHSet(): Set[Any] = hset
/**
 * The set that membership checks are run against, selected by the child's data type.
 *
 * For struct-typed children the elements of `hset` are re-collected into a `TreeSet` that is
 * driven by the interpreted ordering of the child's type, so membership is decided by that
 * ordering rather than by `equals`/`hashCode` of the row objects. For every other type `hset`
 * is used as is.
 *
 * `null` is removed before the tree is built. Null membership is reported through `hasNull`
 * (see nullSafeEval and doGenCode), and the lookups visible in this class only probe the set
 * with non-null values, so the set itself never needs to hold `null`.
 */
@transient private[this] lazy val set = child.dataType match {
  case _: StructType =>
    // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
    // NOTE(review): assumes the interpreted struct ordering does not tolerate null operands,
    // which is why null is filtered out here instead of being inserted into the tree -- confirm.
    TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ (hset - null)
  case _ => hset
}

// Accessor for the set used by membership checks. The generated code calls this through the
// expression reference to initialise its own set field (see doGenCode), so the name and the
// empty parameter list must stay as they are.
def getSet(): Set[Any] = set

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val setName = classOf[Set[Any]].getName
val InSetName = classOf[InSet].getName
val childGen = child.genCode(ctx)
ctx.references += this
val hsetTerm = ctx.freshName("hset")
val setTerm = ctx.freshName("set")
val hasNullTerm = ctx.freshName("hasNull")
ctx.addMutableState(setName, hsetTerm,
s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();")
ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);")
ctx.addMutableState(setName, setTerm,
s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();")
ctx.addMutableState("boolean", hasNullTerm,
s"$hasNullTerm = ${if (hasNull) "true" else "false"};")
ev.copy(code = s"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
boolean ${ev.value} = false;
if (!${ev.isNull}) {
${ev.value} = $hsetTerm.contains(${childGen.value});
${ev.value} = $setTerm.contains(${childGen.value});
if (!${ev.value} && $hasNullTerm) {
${ev.isNull} = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class OptimizeInSuite extends PlanTest {
val optimizedPlan = OptimizeIn(conf.copy(OPTIMIZER_INSET_CONVERSION_THRESHOLD -> 2))(plan)
optimizedPlan match {
case Filter(cond, _)
if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 =>
if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 =>
// pass
case _ => fail("Unexpected result for OptimizedIn")
}
Expand Down
23 changes: 23 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2616,4 +2616,27 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)"))
assert(e.message.contains("Invalid number of arguments"))
}

// Regression test for SPARK-21228: an IN over struct values that is converted to InSet must
// still match rows produced by an aggregate.
test("SPARK-21228: InSet incorrect handling of structs") {
  withTempView("A") {
    // reduce this from the default of 10 so the repro query text is not too long
    withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "3") {
      // a relation that has 1 column of struct type with values (1,1), ..., (9, 9)
      spark.range(1, 10).selectExpr("named_struct('a', id, 'b', id) as a")
        .createOrReplaceTempView("A")
      val df = sql(
        """
          |SELECT * from
          | (SELECT MIN(a) as minA FROM A) AA -- this Aggregate will return UnsafeRows
          | -- the IN will become InSet with a Set of GenericInternalRows
          | -- a GenericInternalRow is never equal to an UnsafeRow so the query would
          | -- returns 0 results, which is incorrect
          | WHERE minA IN (NAMED_STRUCT('a', 1L, 'b', 1L), NAMED_STRUCT('a', 2L, 'b', 2L),
          | NAMED_STRUCT('a', 3L, 'b', 3L))
        """.stripMargin)
      // Only the smallest struct, (1, 1), is expected to survive the filter. The plan is
      // deliberately not printed here so the test does not write to stdout.
      checkAnswer(df, Row(Row(1, 1)))
    }
  }
}
}

0 comments on commit a1bceda

Please sign in to comment.