From da103073f847b45d64a1014d99b4a21ccf3f3533 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 16 May 2018 17:21:14 -0700 Subject: [PATCH 1/6] Added isinSet in DataFrame API for Scala and Java. --- .../sql/catalyst/optimizer/expressions.scala | 7 +- .../sql/catalyst/optimizer/subquery.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../catalyst/optimizer/OptimizeInSuite.scala | 32 ++++++++ .../scala/org/apache/spark/sql/Column.scala | 19 +++++ .../spark/sql/ColumnExpressionSuite.scala | 80 ++++++++++++++++++- 6 files changed, 137 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1c0b7bd806801..821e0e9ef4d22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -219,7 +218,11 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral + case In(v, list) if list.isEmpty => + // When v is not nullable, the following expression will be optimized + // to FalseLiteral which is tested in OptimizeInSuite.scala + If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) + case In(v, list) if list.length == 1 => EqualTo(v, list.head) case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 709db6d8bec7d..97c7a481014e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -149,7 +149,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { newPlan = dedupJoin( Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) exists - case In(value, Seq(ListQuery(sub, conditions, _, _))) => + case EqualTo(value, ListQuery(sub, conditions, _, _)) => val exists = AttributeReference("exists", BooleanType, nullable = false)() val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e487693927ab6..c486ad700f362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -78,7 +78,7 @@ abstract class LogicalPlan schema.map { field => resolve(field.name :: Nil, resolver).map { case a: AttributeReference => a - case other => sys.error(s"can not handle nested schema yet... plan $this") + case _ => sys.error(s"can not handle nested schema yet... plan $this") }.getOrElse { throw new AnalysisException( s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 478118ed709f7..3735022c81f6b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -176,6 +176,21 @@ class OptimizeInSuite extends PlanTest { } } + test("OptimizedIn test: one element in list gets transformed to EqualTo.") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Seq(UnresolvedAttribute("b")))) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("b"))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + "when value is not nullable") { val originalQuery = @@ -191,4 +206,21 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("OptimizedIn test: In empty list gets transformed to " + + "If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) when value is nullable") { + val originalQuery = + testRelation + .where(In(UnresolvedAttribute("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(If(IsNotNull(UnresolvedAttribute("a")), + Literal(false), Literal.create(null, BooleanType))) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ad0efbae89830..5e603f6f9145f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.InterfaceStability @@ -786,6 +787,24 @@ class Column(val expr: Expression) extends Logging { @scala.annotation.varargs def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) } + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided Set. + * + * @group expr_ops + * @since 2.4.0 + */ + def isinSet(values: scala.collection.Set[_]): Column = isin(values.toSeq: _*) + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided Set. + * + * @group java_expr_ops + * @since 2.4.0 + */ + def isinSet(values: java.util.Set[_]): Column = isinSet(values.asScala) + /** * SQL like expression. Returns a boolean column based on a SQL LIKE match. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 7c45be21961d3..69afd0715176a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.util.Locale + +import scala.collection.JavaConverters._ + import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ @@ -392,9 +396,83 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - intercept[AnalysisException] { + val e = intercept[AnalysisException] { df2.filter($"a".isin($"b")) } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isinSet: Scala Set") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + checkAnswer(df.filter($"a".isinSet(Set(1, 2))), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isinSet(Set(3, 2))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isinSet(Set(3, 1))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + // Auto casting should work with mixture of different types in Set + checkAnswer(df.filter($"a".isinSet(Set(1.toShort, "2"))), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isinSet(Set("3", 2.toLong))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isinSet(Set(3, "1"))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + checkAnswer(df.filter($"b".isinSet(Set("y", "x"))), + df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".isinSet(Set("z", "x"))), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".isinSet(Set("z", "y"))), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isinSet(Set($"b"))) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isinSet: Java Set") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + checkAnswer(df.filter($"a".isinSet(Set(1, 2).asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isinSet(Set(3, 2).asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isinSet(Set(3, 1).asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + // Auto casting should work with mixture of different types in Set + checkAnswer(df.filter($"a".isinSet(Set(1.toShort, "2").asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isinSet(Set("3", 2.toLong).asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isinSet(Set(3, "1").asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + checkAnswer(df.filter($"b".isinSet(Set("y", "x").asJava)), + df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".isinSet(Set("z", "x").asJava)), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".isinSet(Set("z", "y").asJava)), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isinSet(Set($"b").asJava)) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } } test("&&") { From 730b19ba8b0e5db3c4bfda150cd156370a321fef Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 24 May 2018 23:40:27 -0700 Subject: [PATCH 2/6] Not optimize In if the list is type of ListQuery --- .../org/apache/spark/sql/catalyst/optimizer/expressions.scala | 3 ++- .../org/apache/spark/sql/catalyst/optimizer/subquery.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 821e0e9ef4d22..261ec625c94af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -222,7 +222,8 @@ object OptimizeIn extends Rule[LogicalPlan] { // When v is not nullable, the following expression will be optimized // to FalseLiteral which is tested in OptimizeInSuite.scala If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) - case In(v, list) if list.length == 1 => EqualTo(v, list.head) + case In(v, list) if list.length == 1 && !list.head.isInstanceOf[ListQuery] => + EqualTo(v, list.head) case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 97c7a481014e4..709db6d8bec7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -149,7 +149,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { newPlan = dedupJoin( Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))) exists - case EqualTo(value, ListQuery(sub, conditions, _, _)) => + case In(value, Seq(ListQuery(sub, conditions, _, _))) => val exists = AttributeReference("exists", BooleanType, nullable = false)() val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled) val newConditions = (inConditions ++ conditions).reduceLeftOption(And) From 6ff280616bab4f38d338b9b3aca38c4affd86cd8 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 24 May 2018 23:57:07 -0700 Subject: [PATCH 3/6] Safer rule --- .../org/apache/spark/sql/catalyst/optimizer/expressions.scala | 3 +-- .../apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 261ec625c94af..6af478e390f33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -222,8 +222,7 @@ object OptimizeIn extends Rule[LogicalPlan] { // When v is not nullable, the following expression will be optimized // to FalseLiteral which is tested in OptimizeInSuite.scala If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) - case In(v, list) if list.length == 1 && !list.head.isInstanceOf[ListQuery] => - EqualTo(v, list.head) + case In(v, Seq(elem @ Literal(_, _))) => EqualTo(v, elem) case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 3735022c81f6b..8349c14c4b8e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -179,13 +179,13 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: one element in list gets transformed to EqualTo.") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) .analyze val optimized = Optimize.execute(originalQuery) val correctAnswer = testRelation - .where(EqualTo(UnresolvedAttribute("a"), UnresolvedAttribute("b"))) + .where(EqualTo(UnresolvedAttribute("a"), Literal(1))) .analyze comparePlans(optimized, correctAnswer) From 286a46849fb8840b6aad243fbc1b8b17b1f6a02b Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Fri, 25 May 2018 17:16:00 -0700 Subject: [PATCH 4/6] Addressed feedback --- .../sql/catalyst/optimizer/expressions.scala | 5 +- .../scala/org/apache/spark/sql/Column.scala | 8 +-- .../spark/sql/ColumnExpressionSuite.scala | 62 ++++++++++++------- 3 files changed, 46 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 6af478e390f33..d3a829fe511dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -222,7 +222,10 @@ object OptimizeIn extends Rule[LogicalPlan] { // When v is not nullable, the following expression will be optimized // to FalseLiteral which is tested in OptimizeInSuite.scala If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) - case In(v, Seq(elem @ Literal(_, _))) => EqualTo(v, elem) + case In(v, Seq(elem @ Literal(_, _))) => + // `Expression` like `ListQuery` contains subquery which can not + // be converted into `EqualTo`. Only `Literal` is converted for safety. + EqualTo(v, elem) case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 5e603f6f9145f..b3e59f53ee3de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -789,21 +789,21 @@ class Column(val expr: Expression) extends Logging { /** * A boolean expression that is evaluated to true if the value of this expression is contained - * by the provided Set. + * by the provided collection. * * @group expr_ops * @since 2.4.0 */ - def isinSet(values: scala.collection.Set[_]): Column = isin(values.toSeq: _*) + def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*) /** * A boolean expression that is evaluated to true if the value of this expression is contained - * by the provided Set. + * by the provided collection. * * @group java_expr_ops * @since 2.4.0 */ - def isinSet(values: java.util.Set[_]): Column = isinSet(values.asScala) + def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala) /** * SQL like expression. Returns a boolean column based on a SQL LIKE match. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 69afd0715176a..bd58496288d9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -405,34 +405,42 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } } - test("isinSet: Scala Set") { + test("isInCollection: Scala Collection") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - checkAnswer(df.filter($"a".isinSet(Set(1, 2))), + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2))), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set(3, 2))), + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2))), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set(3, 1))), + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - // Auto casting should work with mixture of different types in Set - checkAnswer(df.filter($"a".isinSet(Set(1.toShort, "2"))), + // Auto casting should work with mixture of different types in collections + checkAnswer(df.filter($"a".isInCollection(Seq(1.toShort, "2"))), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set("3", 2.toLong))), + checkAnswer(df.filter($"a".isInCollection(Seq("3", 2.toLong))), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set(3, "1"))), + checkAnswer(df.filter($"a".isInCollection(Seq(3, "1"))), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"b".isinSet(Set("y", "x"))), + checkAnswer(df.filter($"b".isInCollection(Seq("y", "x"))), df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".isinSet(Set("z", "x"))), + checkAnswer(df.filter($"b".isInCollection(Seq("z", "x"))), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".isinSet(Set("z", "y"))), + checkAnswer(df.filter($"b".isInCollection(Seq("z", "y"))), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") val e = intercept[AnalysisException] { - df2.filter($"a".isinSet(Set($"b"))) + df2.filter($"a".isInCollection(Seq($"b"))) } Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") .foreach { s => @@ -440,34 +448,40 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } } - test("isinSet: Java Set") { + test("isInCollection: Java Collection") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - checkAnswer(df.filter($"a".isinSet(Set(1, 2).asJava)), + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set(3, 2).asJava)), + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set(3, 1).asJava)), + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - // Auto casting should work with mixture of different types in Set - checkAnswer(df.filter($"a".isinSet(Set(1.toShort, "2").asJava)), + // Auto casting should work with mixture of different types in collections + checkAnswer(df.filter($"a".isInCollection(Seq(1.toShort, "2").asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set("3", 2.toLong).asJava)), + checkAnswer(df.filter($"a".isInCollection(Seq("3", 2.toLong).asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isinSet(Set(3, "1").asJava)), + checkAnswer(df.filter($"a".isInCollection(Seq(3, "1").asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - checkAnswer(df.filter($"b".isinSet(Set("y", "x").asJava)), + checkAnswer(df.filter($"b".isInCollection(Seq("y", "x").asJava)), df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".isinSet(Set("z", "x").asJava)), + checkAnswer(df.filter($"b".isInCollection(Seq("z", "x").asJava)), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".isinSet(Set("z", "y").asJava)), + checkAnswer(df.filter($"b".isInCollection(Seq("z", "y").asJava)), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") val e = intercept[AnalysisException] { - df2.filter($"a".isinSet(Set($"b").asJava)) + df2.filter($"a".isInCollection(Seq($"b").asJava)) } Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") .foreach { s => From 1332406d7f4ca7a9a4a85338f758430ecc334ff8 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Sun, 27 May 2018 22:30:28 -0700 Subject: [PATCH 5/6] Removed rules in OptimizeIn into another PR. --- .../sql/catalyst/optimizer/expressions.scala | 9 +----- .../catalyst/optimizer/OptimizeInSuite.scala | 32 ------------------- 2 files changed, 1 insertion(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index d3a829fe511dc..1d363b8146e3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -218,14 +218,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - case In(v, list) if list.isEmpty => - // When v is not nullable, the following expression will be optimized - // to FalseLiteral which is tested in OptimizeInSuite.scala - If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) - case In(v, Seq(elem @ Literal(_, _))) => - // `Expression` like `ListQuery` contains subquery which can not - // be converted into `EqualTo`. Only `Literal` is converted for safety. - EqualTo(v, elem) + case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 8349c14c4b8e3..478118ed709f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -176,21 +176,6 @@ class OptimizeInSuite extends PlanTest { } } - test("OptimizedIn test: one element in list gets transformed to EqualTo.") { - val originalQuery = - testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1)))) - .analyze - - val optimized = Optimize.execute(originalQuery) - val correctAnswer = - testRelation - .where(EqualTo(UnresolvedAttribute("a"), Literal(1))) - .analyze - - comparePlans(optimized, correctAnswer) - } - test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + "when value is not nullable") { val originalQuery = @@ -206,21 +191,4 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - - test("OptimizedIn test: In empty list gets transformed to " + - "If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) when value is nullable") { - val originalQuery = - testRelation - .where(In(UnresolvedAttribute("a"), Nil)) - .analyze - - val optimized = Optimize.execute(originalQuery) - val correctAnswer = - testRelation - .where(If(IsNotNull(UnresolvedAttribute("a")), - Literal(false), Literal.create(null, BooleanType))) - .analyze - - comparePlans(optimized, correctAnswer) - } } From fed2846fe7c9ca2cb4534b23803cd29d5a18d4f9 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 28 May 2018 23:44:14 -0700 Subject: [PATCH 6/6] Simplify the tests --- .../spark/sql/ColumnExpressionSuite.scala | 52 ++++--------------- 1 file changed, 10 insertions(+), 42 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index bd58496288d9b..2182bd7eadd63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -394,6 +394,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer(df.filter($"b".isin("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + // Auto casting should work with mixture of different types in collections + checkAnswer(df.filter($"a".isin(1.toShort, "2")), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin("3", 2.toLong)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin(3, "1")), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") val e = intercept[AnalysisException] { @@ -407,29 +415,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("isInCollection: Scala Collection") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") - checkAnswer(df.filter($"a".isInCollection(Seq(1, 2))), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 2))), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + // Test with different types of collections checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - - // Auto casting should work with mixture of different types in collections - checkAnswer(df.filter($"a".isInCollection(Seq(1.toShort, "2"))), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq("3", 2.toLong))), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, "1"))), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - - checkAnswer(df.filter($"b".isInCollection(Seq("y", "x"))), - df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".isInCollection(Seq("z", "x"))), - df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".isInCollection(Seq("z", "y"))), - df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) - - // Test with different types of collections checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), @@ -450,29 +438,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("isInCollection: Java Collection") { val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - - // Auto casting should work with mixture of different types in collections - checkAnswer(df.filter($"a".isInCollection(Seq(1.toShort, "2").asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq("3", 2.toLong).asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) - checkAnswer(df.filter($"a".isInCollection(Seq(3, "1").asJava)), - df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) - - checkAnswer(df.filter($"b".isInCollection(Seq("y", "x").asJava)), - df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".isInCollection(Seq("z", "x").asJava)), - df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) - checkAnswer(df.filter($"b".isInCollection(Seq("z", "y").asJava)), - df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) - - // Test with different types of collections checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)),