[SPARK-31553][SQL] Fix isInCollection for collection sizes above the optimisation threshold #28328

Closed · wants to merge 9 commits
@@ -426,10 +426,22 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
* Optimized version of In clause, when all filter values of In clause are
* static.
*/
case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate {
case class InSet(
child: Expression,
hset: Set[Any],
hsetElemType: DataType) extends UnaryExpression with Predicate {
Member Author

Matching internal Catalyst types to external types is ambiguous. For example:
Long -> Long
Long -> Timestamp

Also, the type of child can be unknown at the point where InSet needs to know the Catalyst type of the hset elements.

hsetElemType is needed to eliminate the ambiguity.
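
For instance, a minimal sketch of the ambiguity using CatalystTypeConverters.convertToScala (the value is made up, and the Timestamp result assumes the default external type mapping):

import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.types.{LongType, TimestampType}

// The same Catalyst value (a Long holding microseconds) maps to two different
// external types, so the reverse mapping cannot be derived from the value alone.
val catalystValue: Any = 1000000L
convertToScala(catalystValue, LongType)      // 1000000L (a Scala Long)
convertToScala(catalystValue, TimestampType) // 1970-01-01 00:00:01.0 (a java.sql.Timestamp)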

Member

Do you think we can make this an Option[DataType], because only a few things are ambiguous?

Member Author

We can, but if a caller passes None, InSet will not be able to infer the element type when child.dataType is NullType, as in this case: dataType returns NullType if child is a PrettyAttribute.

Contributor

When can hsetElemType be different from child.dataType?

Member Author

When InSet is created from isInCollection; in that case child.dataType is NullType. For example, it is NullType in the test https://github.com/apache/spark/pull/28328/files#diff-aa655ba249e00d2591b21cf6a360cf82R886 because child is a PrettyAttribute when the sql method is called.

Member Author

And InSet.sql() is called from Dataset.select via _.named:

Project(untypedCols.map(_.named), logicalPlan)

The named method calls toPrettySQL(expr):

case expr: Expression => Alias(expr, toPrettySQL(expr))()

The toPrettySQL method calls sql:

def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql
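
A minimal sketch of that situation (PrettyAttribute and InSet are Catalyst internals, so this assumes code that can see org.apache.spark.sql.catalyst.expressions):

import org.apache.spark.sql.catalyst.expressions.{InSet, PrettyAttribute}
import org.apache.spark.sql.types.{NullType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Pretty-printed attributes default to NullType, so child.dataType tells InSet nothing.
val child = PrettyAttribute("a", NullType)
val hset: Set[Any] = Set("a", "b").map(UTF8String.fromString)

// Before this change, InSet.sql converted the UTF8String elements back to external
// values via child.dataType (NullType) and produced a broken SQL string; with the
// explicit element type the conversion is unambiguous.
println(InSet(child, hset, StringType).sql)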


require(hset != null, "hset could not be null")

override def checkInputDataTypes(): TypeCheckResult = {
if (!DataType.equalsStructurally(child.dataType, hsetElemType, ignoreNullability = true)) {
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
s"${child.dataType.catalogString} != ${hsetElemType.catalogString}")
} else {
TypeUtils.checkForOrderingExpr(child.dataType, s"function $prettyName")
}
}

override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"

@transient private[this] lazy val hasNull: Boolean = hset.contains(null)
@@ -446,12 +458,12 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
}

@transient lazy val set: Set[Any] = child.dataType match {
@transient lazy val set: Set[Any] = hsetElemType match {
case t: AtomicType if !t.isInstanceOf[BinaryType] => hset
case _: NullType => hset
case _ =>
// for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ (hset - null)
TreeSet.empty(TypeUtils.getInterpretedOrdering(hsetElemType)) ++ (hset - null)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -462,7 +474,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
}

private def canBeComputedUsingSwitch: Boolean = child.dataType match {
private def canBeComputedUsingSwitch: Boolean = hsetElemType match {
case ByteType | ShortType | IntegerType | DateType => true
case _ => false
}
@@ -521,7 +533,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
override def sql: String = {
val valueSQL = child.sql
val listSQL = hset.toSeq
.map(elem => Literal(convertToScala(elem, child.dataType)).sql)
.map(elem => Literal(convertToScala(elem, hsetElemType)).sql)
.mkString(", ")
s"($valueSQL IN ($listSQL))"
}
@@ -251,7 +251,7 @@ object OptimizeIn extends Rule[LogicalPlan] {
EqualTo(v, newList.head)
} else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
val hSet = newList.map(e => e.eval(EmptyRow))
InSet(v, HashSet() ++ hSet)
InSet(v, HashSet() ++ hSet, v.dataType)
} else if (newList.length < list.length) {
expr.copy(list = newList)
} else { // newList.length == list.length && newList.length > 1
@@ -172,7 +172,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
val hSet = expList.map(e => e.eval())
evaluateInSet(ar, HashSet() ++ hSet, update)

case InSet(ar: Attribute, set) =>
case InSet(ar: Attribute, set, _) =>
evaluateInSet(ar, set, update)

// In current stage, we don't have advanced statistics such as sketches or histograms.
@@ -130,7 +130,9 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
private def checkInAndInSet(in: In, expected: Any): Unit = {
// expecting all in.list are Literal or NonFoldableLiteral.
checkEvaluation(in, expected)
checkEvaluation(InSet(in.value, HashSet() ++ in.list.map(_.eval())), expected)
checkEvaluation(
InSet(in.value, HashSet() ++ in.list.map(_.eval()), in.value.dataType),
expected)
}

test("basic IN/INSET predicate test") {
@@ -154,7 +156,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
Literal(2)))),
true)
checkEvaluation(
And(InSet(Literal(1), HashSet(1, 2)), InSet(Literal(2), Set(1, 2))),
And(InSet(Literal(1), HashSet(1, 2), IntegerType), InSet(Literal(2), Set(1, 2), IntegerType)),
true)

val ns = NonFoldableLiteral.create(null, StringType)
@@ -256,12 +258,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {

val nullLiteral = Literal(null, presentValue.dataType)

checkEvaluation(InSet(nullLiteral, values), expected = null)
checkEvaluation(InSet(nullLiteral, values + null), expected = null)
checkEvaluation(InSet(presentValue, values), expected = true)
checkEvaluation(InSet(presentValue, values + null), expected = true)
checkEvaluation(InSet(absentValue, values), expected = false)
checkEvaluation(InSet(absentValue, values + null), expected = null)
checkEvaluation(InSet(nullLiteral, values, nullLiteral.dataType), expected = null)
checkEvaluation(InSet(nullLiteral, values + null, nullLiteral.dataType), expected = null)
checkEvaluation(InSet(presentValue, values, presentValue.dataType), expected = true)
checkEvaluation(InSet(presentValue, values + null, presentValue.dataType), expected = true)
checkEvaluation(InSet(absentValue, values, absentValue.dataType), expected = false)
checkEvaluation(InSet(absentValue, values + null, absentValue.dataType), expected = null)
}

def checkAllTypes(): Unit = {
@@ -498,7 +500,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {

test("SPARK-22693: InSet should not use global variables") {
val ctx = new CodegenContext
InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx)
InSet(Literal(1), Set(1, 2, 3, 4), IntegerType).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}

@@ -535,7 +537,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {

test("SPARK-29100: InSet with empty input set") {
val row = create_row(1)
val inSet = InSet(BoundReference(0, IntegerType, true), Set.empty)
val inSet = InSet(BoundReference(0, IntegerType, true), Set.empty, IntegerType)
checkEvaluation(inSet, false, row)
}
}
@@ -85,7 +85,7 @@ class OptimizeInSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet))
.where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet, IntegerType))
.analyze

comparePlans(optimized, correctAnswer)
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{ColumnStatsM
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* In this test suite, we test predicates containing the following operators:
@@ -352,15 +353,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase {

test("cint IN (3, 4, 5)") {
validateEstimatedStats(
Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)),
Filter(InSet(attrInt, Set(3, 4, 5), IntegerType), childStatsTestPlan(Seq(attrInt), 10L)),
Seq(attrInt -> ColumnStat(distinctCount = Some(3), min = Some(3), max = Some(5),
nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 3)
}

test("evaluateInSet with all zeros") {
validateEstimatedStats(
Filter(InSet(attrString, Set(3, 4, 5)),
Filter(InSet(attrString, Set(3, 4, 5), IntegerType),
StatsTestPlan(Seq(attrString), 0,
AttributeMap(Seq(attrString ->
ColumnStat(distinctCount = Some(0), min = None, max = None,
@@ -371,7 +372,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {

test("evaluateInSet with string") {
validateEstimatedStats(
Filter(InSet(attrString, Set("A0")),
Filter(InSet(attrString, Set(UTF8String.fromString("A0")), StringType),
StatsTestPlan(Seq(attrString), 10,
AttributeMap(Seq(attrString ->
ColumnStat(distinctCount = Some(10), min = None, max = None,
@@ -383,14 +384,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase {

test("cint NOT IN (3, 4, 5)") {
validateEstimatedStats(
Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)),
Filter(Not(InSet(attrInt, Set(3, 4, 5), IntegerType)), childStatsTestPlan(Seq(attrInt), 10L)),
Seq(attrInt -> colStatInt.copy(distinctCount = Some(7))),
expectedRowCount = 7)
}

test("cbool IN (true)") {
validateEstimatedStats(
Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)),
Filter(InSet(attrBool, Set(true), BooleanType), childStatsTestPlan(Seq(attrBool), 10L)),
Seq(attrBool -> ColumnStat(distinctCount = Some(1), min = Some(true), max = Some(true),
nullCount = Some(0), avgLen = Some(1), maxLen = Some(1))),
expectedRowCount = 5)
@@ -510,7 +511,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
attributeStats = AttributeMap(Seq(attrInt -> cornerChildColStatInt))
)
validateEstimatedStats(
Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
Filter(InSet(attrInt, Set(1, 2, 3, 4, 5), IntegerType), cornerChildStatsTestplan),
Seq(attrInt -> ColumnStat(distinctCount = Some(2), min = Some(1), max = Some(5),
nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))),
expectedRowCount = 2)
9 changes: 5 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -828,11 +828,12 @@
* @since 2.4.0
*/
def isInCollection(values: scala.collection.Iterable[_]): Column = withExpr {
val hSet = values.toSet[Any]
if (hSet.size > SQLConf.get.optimizerInSetConversionThreshold) {
InSet(expr, hSet)
val exprValues = values.toSeq.map(lit(_).expr)
if (exprValues.size > SQLConf.get.optimizerInSetConversionThreshold) {
val elemType = exprValues.headOption.map(_.dataType).getOrElse(NullType)
InSet(expr, exprValues.map(_.eval()).toSet, elemType)
Contributor

How can we make sure the expr has the same data type as exprValues? Do we have a type coercion rule for it?

Member Author

To make sure of that, InSet needs something similar to In.checkInputDataTypes():

override def checkInputDataTypes(): TypeCheckResult = {
  val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType,
    ignoreNullability = true))
  if (mismatchOpt.isDefined) {
    TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
      s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}")
  } else {
    TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
  }
}

I could add such a check in this PR if you don't mind.

Member Author

I added a similar check to InSet.

} else {
In(expr, values.toSeq.map(lit(_).expr))
In(expr, exprValues)
Member

So, this is caused by SPARK-29048 (Improve performance on Column.isInCollection() with a large size collection, #25754) and only affects 3.0.0, right?

Member Author

Correct

Member

Thanks for confirming, @MaxGekk.
cc @WeichenXu123 and @gatorsmile

}
}
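
An end-to-end sketch of the scenario this hunk fixes (the local master, app name, and values are illustrative, not part of the patch):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("isInCollection-repro").getOrCreate()
import spark.implicits._

// More values than spark.sql.optimizer.inSetConversionThreshold (10 by default),
// so isInCollection builds an InSet expression instead of an In expression.
val values = (0 until 200).map(_.toString)
val df = Seq("10").toDF("x")
df.select($"x".isInCollection(values)).show()  // should show true for "10" once the fix is in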

@@ -479,7 +479,7 @@ object DataSourceStrategy {
case expressions.LessThanOrEqual(Literal(v, t), PushableColumn(name)) =>
Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))

case expressions.InSet(e @ PushableColumn(name), set) =>
case expressions.InSet(e @ PushableColumn(name), set, _) =>
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
Some(sources.In(name, set.toArray.map(toScala)))

@@ -89,7 +89,7 @@ object FileSourceStrategy extends Strategy with Logging {
case expressions.In(a: Attribute, list)
if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow)))
case expressions.InSet(a: Attribute, hset)
case expressions.InSet(a: Attribute, hset, _)
if hset.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
getBucketSetFromIterable(a, hset.map(e => expressions.Literal(e).eval(EmptyRow)))
case expressions.IsNull(a: Attribute) if a.name == bucketColumnName =>
@@ -159,7 +159,7 @@ case class InSubqueryExec(

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
prepareResult()
InSet(child, result.toSet).doGenCode(ctx, ev)
InSet(child, result.toSet, child.dataType).doGenCode(ctx, ev)
}

override lazy val canonicalized: InSubqueryExec = {
@@ -483,6 +483,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
"due to data type mismatch: Arguments must be same type but were").foreach { s =>
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
}
val errMsg = intercept[AnalysisException] {
df.select($"a".isInCollection(Seq(0, 1).map(new java.sql.Timestamp(_)))).collect()
}.getMessage
assert(errMsg.contains("Arguments must be same type"))
}
}
}
@@ -872,7 +876,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
}

test("SPARK-31563: sql of InSet for UTF8String collection") {
val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString))
val inSet = InSet(Literal("a"), Set("a", "b").map(UTF8String.fromString), StringType)
assert(inSet.sql === "('a' IN ('a', 'b'))")
}

test("SPARK-31553: isInCollection for collection sizes above a threshold") {
Member

Thank you, @MaxGekk.

cc @aokolnychyi and @dbtsai

val threshold = 100
withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> threshold.toString) {
val set = (0 until 2 * threshold).map(_.toString).toSet
val elem = "10"
val data = Seq(elem).toDF("x")
assert(set.contains(elem))
checkAnswer(data.select($"x".isInCollection(set)), Row(true))
}
}
}
@@ -110,7 +110,9 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession {
testTranslateFilter(LessThanOrEqual(1, attrInt),
Some(sources.GreaterThanOrEqual(intColName, 1)))

testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In(intColName, Array(1, 2, 3))))
testTranslateFilter(
InSet(attrInt, Set(1, 2, 3), IntegerType),
Some(sources.In(intColName, Array(1, 2, 3))))

testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In(intColName, Array(1, 2, 3))))

@@ -37,6 +37,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.BitSet

@@ -188,8 +189,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
df)

// Case 4: InSet
val inSetExpr = expressions.InSet($"j".expr,
Set(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3).map(lit(_).expr))
val inSetExpr = expressions.InSet(
$"j".expr,
Set(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3).map(lit(_).expr),
IntegerType)
checkPrunedAnswers(
bucketSpec,
bucketValues = Seq(bucketValue, bucketValue + 1, bucketValue + 2, bucketValue + 3),
@@ -740,7 +740,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
if useAdvanced =>
Some(convertInToOr(name, values))

case InSet(ExtractAttribute(SupportedAttribute(name)), ExtractableValues(values))
case InSet(ExtractAttribute(SupportedAttribute(name)), ExtractableValues(values), _)
if useAdvanced =>
Some(convertInToOr(name, values))

@@ -213,7 +213,7 @@ class HivePartitionFilteringSuite(version: String)
0 to 4,
"aa" :: "ab" :: "ba" :: "bb" :: Nil, {
case expr @ In(v, list) if expr.inSetConvertible =>
InSet(v, list.map(_.eval(EmptyRow)).toSet)
InSet(v, list.map(_.eval(EmptyRow)).toSet, v.dataType)
})
}

@@ -225,7 +225,7 @@ class HivePartitionFilteringSuite(version: String)
0 to 4,
"aa" :: "ab" :: "ba" :: "bb" :: Nil, {
case expr @ In(v, list) if expr.inSetConvertible =>
InSet(v, list.map(_.eval(EmptyRow)).toSet)
InSet(v, list.map(_.eval(EmptyRow)).toSet, v.dataType)
})
}

@@ -244,7 +244,7 @@ class HivePartitionFilteringSuite(version: String)
0 to 4,
"ab" :: "ba" :: Nil, {
case expr @ In(v, list) if expr.inSetConvertible =>
InSet(v, list.map(_.eval(EmptyRow)).toSet)
InSet(v, list.map(_.eval(EmptyRow)).toSet, v.dataType)
})
}
