diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 6af478e390f33..d3a829fe511dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -222,7 +222,10 @@ object OptimizeIn extends Rule[LogicalPlan] { // When v is not nullable, the following expression will be optimized // to FalseLiteral which is tested in OptimizeInSuite.scala If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) - case In(v, Seq(elem @ Literal(_, _))) => EqualTo(v, elem) + case In(v, Seq(elem @ Literal(_, _))) => + // `Expression` like `ListQuery` contains subquery which can not + // be converted into `EqualTo`. Only `Literal` is converted for safety. + EqualTo(v, elem) case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 5e603f6f9145f..b3e59f53ee3de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -789,21 +789,21 @@ class Column(val expr: Expression) extends Logging { /** * A boolean expression that is evaluated to true if the value of this expression is contained - * by the provided Set. + * by the provided collection. * * @group expr_ops * @since 2.4.0 */ - def isinSet(values: scala.collection.Set[_]): Column = isin(values.toSeq: _*) + def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*) /** * A boolean expression that is evaluated to true if the value of this expression is contained - * by the provided Set. + * by the provided collection. * * @group java_expr_ops * @since 2.4.0 */ - def isinSet(values: java.util.Set[_]): Column = isinSet(values.asScala) + def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala) /** * SQL like expression. Returns a boolean column based on a SQL LIKE match. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 69afd0715176a..bd58496288d9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -405,34 +405,42 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } } - test("isinSet: Scala Set") { + test("isInCollection: Scala Collection") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - checkAnswer(df.filter($"a".isinSet(Set(1, 2))), + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2))), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set(3, 2))), + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2))), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set(3, 1))), + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - // Auto casting should work with mixture of different types in Set - checkAnswer(df.filter($"a".isinSet(Set(1.toShort, "2"))), + // Auto casting should work with mixture of different types in collections + checkAnswer(df.filter($"a".isInCollection(Seq(1.toShort, "2"))), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set("3", 2.toLong))), + checkAnswer(df.filter($"a".isInCollection(Seq("3", 2.toLong))), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set(3, "1"))), + checkAnswer(df.filter($"a".isInCollection(Seq(3, "1"))), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"b".isinSet(Set("y", "x"))), + checkAnswer(df.filter($"b".isInCollection(Seq("y", "x"))), df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".isinSet(Set("z", "x"))), + checkAnswer(df.filter($"b".isInCollection(Seq("z", "x"))), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".isinSet(Set("z", "y"))), + checkAnswer(df.filter($"b".isInCollection(Seq("z", "y"))), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") val e = intercept[AnalysisException] { - df2.filter($"a".isinSet(Set($"b"))) + df2.filter($"a".isInCollection(Seq($"b"))) } Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") .foreach { s => @@ -440,34 +448,40 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } } - test("isinSet: Java Set") { + test("isInCollection: Java Collection") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - checkAnswer(df.filter($"a".isinSet(Set(1, 2).asJava)), + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set(3, 2).asJava)), + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set(3, 1).asJava)), + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - // Auto casting should work with mixture of different types in Set - checkAnswer(df.filter($"a".isinSet(Set(1.toShort, "2").asJava)), + // Auto casting should work with mixture of different types in collections + checkAnswer(df.filter($"a".isInCollection(Seq(1.toShort, "2").asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set("3", 2.toLong).asJava)), + checkAnswer(df.filter($"a".isInCollection(Seq("3", 2.toLong).asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set(3, "1").asJava)), + checkAnswer(df.filter($"a".isInCollection(Seq(3, "1").asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"b".isinSet(Set("y", "x").asJava)), + checkAnswer(df.filter($"b".isInCollection(Seq("y", "x").asJava)), df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".isinSet(Set("z", "x").asJava)), + checkAnswer(df.filter($"b".isInCollection(Seq("z", "x").asJava)), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".isinSet(Set("z", "y").asJava)), + checkAnswer(df.filter($"b".isInCollection(Seq("z", "y").asJava)), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") val e = intercept[AnalysisException] { - df2.filter($"a".isinSet(Set($"b").asJava)) + df2.filter($"a".isInCollection(Seq($"b").asJava)) } Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") .foreach { s =>