[SPARK-31553][SQL] Revert "[SPARK-29048] Improve performance on Column.isInCollection() with a large size collection"

### What changes were proposed in this pull request?
This reverts commit 5631a96.

Closes #28328

### Why are the changes needed?
PR #25754 introduced a bug in `isInCollection`. For example, with the SQL config `spark.sql.optimizer.inSetConversionThreshold` set to 10 (its default value):
```scala
val set = (0 to 20).map(_.toString).toSet
val data = Seq("1").toDF("x")
data.select($"x".isInCollection(set).as("isInCollection")).show()
```
The function must return **true** because "1" is in the set of "0" ... "20", but it returns **false**:
```
+--------------+
|isInCollection|
+--------------+
|         false|
+--------------+
```

### Does this PR introduce any user-facing change?
Yes. With the revert applied, the example above returns `true` as expected.

### How was this patch tested?
```
$ ./build/sbt "test:testOnly *ColumnExpressionSuite"
```

Closes #28388 from MaxGekk/fix-isInCollection-revert.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit b7cabc8)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
MaxGekk authored and cloud-fan committed Apr 28, 2020
1 parent d918477 commit efe700c
Showing 2 changed files with 18 additions and 37 deletions.
sql/core/src/main/scala/org/apache/spark/sql/Column.scala (10 changes: 1 addition & 9 deletions)
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.lit
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 private[sql] object Column {
@@ -827,14 +826,7 @@ class Column(val expr: Expression) extends Logging {
    * @group expr_ops
    * @since 2.4.0
    */
-  def isInCollection(values: scala.collection.Iterable[_]): Column = withExpr {
-    val hSet = values.toSet[Any]
-    if (hSet.size > SQLConf.get.optimizerInSetConversionThreshold) {
-      InSet(expr, hSet)
-    } else {
-      In(expr, values.toSeq.map(lit(_).expr))
-    }
-  }
+  def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*)
 
   /**
    * A boolean expression that is evaluated to true if the value of this expression is contained
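
With the revert, `isInCollection` simply delegates to `isin`, which builds an `In` expression over literals. As for the performance concern of the reverted PR, the existing `OptimizeIn` rule can still rewrite such an `In` to `InSet` during optimization once the literal list exceeds `spark.sql.optimizer.inSetConversionThreshold`. The sketch below is my own check (assuming a local `SparkSession`), not part of this commit:
```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.InSet

val spark = SparkSession.builder().master("local[*]").appName("inset-check").getOrCreate()
import spark.implicits._

// 21 string values, above the default threshold of 10.
val values = (0 to 20).map(_.toString)
val df = spark.range(30).select($"id".cast("string").as("x"))
val optimized = df.filter($"x".isInCollection(values)).queryExecution.optimizedPlan

// Does any expression in the optimized plan use InSet?
val usesInSet = optimized.collect {
  case plan => plan.expressions.exists(_.find(_.isInstanceOf[InSet]).isDefined)
}.exists(identity)
println(usesInSet)  // expected: true if OptimizeIn rewrote the large In list

spark.stop()
```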
sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala (17 additions & 28 deletions)
@@ -26,7 +26,7 @@ import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
 import org.scalatest.Matchers._
 
-import org.apache.spark.sql.catalyst.expressions.{In, InSet, Literal, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{InSet, Literal, NamedExpression}
 import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -455,36 +455,25 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
 
   test("isInCollection: Scala Collection") {
     val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+    // Test with different types of collections
+    checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))),
+      df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
+    checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)),
+      df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
+    checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)),
+      df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
+    checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)),
+      df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
 
-    Seq(1, 2).foreach { conf =>
-      withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> conf.toString) {
-        if (conf <= 1) {
-          assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[InSet], "Expect expr to be InSet")
-        } else {
-          assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[In], "Expect expr to be In")
-        }
-
-        // Test with different types of collections
-        checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))),
-          df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
-        checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)),
-          df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
-        checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)),
-          df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
-        checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)),
-          df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
-
-        val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
+    val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
 
-        val e = intercept[AnalysisException] {
-          df2.filter($"a".isInCollection(Seq($"b")))
-        }
-        Seq("cannot resolve",
-          "due to data type mismatch: Arguments must be same type but were").foreach { s =>
-          assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
-        }
-      }
-    }
+    val e = intercept[AnalysisException] {
+      df2.filter($"a".isInCollection(Seq($"b")))
+    }
+    Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
+      .foreach { s =>
+      assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
+    }
   }
 
   test("&&") {
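
For completeness, a self-contained run of the example from the description (my own snippet, assuming a local `SparkSession`; with this revert applied it prints `true`):
```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("isInCollection-repro").getOrCreate()
import spark.implicits._

val set = (0 to 20).map(_.toString).toSet
val data = Seq("1").toDF("x")
// With isInCollection delegating to isin(...), "1" is correctly found in the set.
data.select($"x".isInCollection(set).as("isInCollection")).show()
// +--------------+
// |isInCollection|
// +--------------+
// |          true|
// +--------------+

spark.stop()
```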
