[SPARK-24371] [SQL] Added isInCollection in DataFrame API for Scala and Java. #21416

Closed · wants to merge 6 commits
Changes from 4 commits
@@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet
 import scala.collection.mutable.{ArrayBuffer, Stack}

 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -219,7 +218,14 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
 object OptimizeIn extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsDown {
-      case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral
+      case In(v, list) if list.isEmpty =>
Contributor: This improvement looks reasonable, but can we move these changes to a separate PR? They're not related to adding isInCollection.

+        // When v is not nullable, the following expression will be optimized
+        // to FalseLiteral, which is tested in OptimizeInSuite.scala.
+        If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType))
+      case In(v, Seq(elem @ Literal(_, _))) =>
+        // An `Expression` like `ListQuery` contains a subquery which cannot
+        // be converted into `EqualTo`. Only `Literal` is converted, for safety.
+        EqualTo(v, elem)
       case expr @ In(v, list) if expr.inSetConvertible =>
         val newList = ExpressionSet(list).toSeq
         if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) {
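A minimal sketch of the semantics the new empty-list branch preserves, assuming a local SparkSession (the setup here is illustrative, not part of the patch):

    import org.apache.spark.sql.SparkSession

    // Minimal sketch: `isin()` with no arguments builds In(a, Nil), which
    // OptimizeIn now rewrites to If(IsNotNull(a), false, null) instead of
    // folding straight to false when the column is nullable.
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val df = Seq(Option(1), None).toDF("a")  // nullable column "a"
    // The row holding 1 evaluates to false; the null row evaluates to null,
    // not false, so the filter drops both while keeping three-valued logic.
    df.filter($"a".isin()).show()            // empty result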
@@ -78,7 +78,7 @@ abstract class LogicalPlan
     schema.map { field =>
       resolve(field.name :: Nil, resolver).map {
         case a: AttributeReference => a
-        case other => sys.error(s"can not handle nested schema yet...  plan $this")
+        case _ => sys.error(s"can not handle nested schema yet...  plan $this")
       }.getOrElse {
         throw new AnalysisException(
           s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]")
@@ -176,6 +176,21 @@ class OptimizeInSuite extends PlanTest {
     }
   }

+  test("OptimizedIn test: one element in list gets transformed to EqualTo.") {
+    val originalQuery =
+      testRelation
+        .where(In(UnresolvedAttribute("a"), Seq(Literal(1))))
+        .analyze
+
+    val optimized = Optimize.execute(originalQuery)
+    val correctAnswer =
+      testRelation
+        .where(EqualTo(UnresolvedAttribute("a"), Literal(1)))
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("OptimizedIn test: In empty list gets transformed to FalseLiteral " +
     "when value is not nullable") {
     val originalQuery =
@@ -191,4 +206,21 @@

     comparePlans(optimized, correctAnswer)
   }
+
+  test("OptimizedIn test: In empty list gets transformed to " +
+    "If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType)) when value is nullable") {
+    val originalQuery =
+      testRelation
+        .where(In(UnresolvedAttribute("a"), Nil))
+        .analyze
+
+    val optimized = Optimize.execute(originalQuery)
+    val correctAnswer =
+      testRelation
+        .where(If(IsNotNull(UnresolvedAttribute("a")),
+          Literal(false), Literal.create(null, BooleanType)))
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }
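The one-element rewrite is also easy to observe from a spark-shell; a quick sketch (the attribute numbering in the plan output will vary):

    // Sketch: after OptimizeIn, a single-element IN list should appear in the
    // optimized logical plan as a plain equality predicate.
    val df = Seq(1, 2, 3).toDF("a")
    df.filter($"a".isin(1)).explain(true)
    // expect a Filter containing something like (a#0 = 1) rather than a IN (1)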
19 changes: 19 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,6 +17,7 @@

 package org.apache.spark.sql

+import scala.collection.JavaConverters._
 import scala.language.implicitConversions

 import org.apache.spark.annotation.InterfaceStability
@@ -786,6 +787,24 @@ class Column(val expr: Expression) extends Logging {
   @scala.annotation.varargs
   def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) }
Contributor: I know this is not your change, but I think (both here and below) something about the automagical type casting that's going on should be in the docstring/scaladoc/javadoc, because to me it's a little surprising that this will compare integers to strings and silently convert the types, including when there are no strings that can be converted to integers. I'd also include that in the isInCollection docstring/scaladoc/javadoc below.

I'd also point out that the result of the conversion needs to be of the same type and not a sequence of the type (although the error message we get is pretty clear, so your call).

Just a suggestion for improvement.

Member (Author): +1. Let's do it in the follow-up PR. Thanks.
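A short sketch of the implicit casting being discussed, assuming a plain integer column:

    // The Int column and the mixed Int/String list are widened to a common
    // type before comparison, so the string "2" silently matches the integer 2.
    val df = Seq(1, 2, 3).toDF("a")
    df.filter($"a".isin("2", 3)).show()  // keeps the rows with 2 and 3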


+  /**
+   * A boolean expression that is evaluated to true if the value of this expression is contained
+   * by the provided collection.
+   *
+   * @group expr_ops
+   * @since 2.4.0
+   */
+  def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*)
+
+  /**
+   * A boolean expression that is evaluated to true if the value of this expression is contained
+   * by the provided collection.
+   *
+   * @group java_expr_ops
+   * @since 2.4.0
+   */
+  def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala)
Contributor: Not that we need it for sure, but in the past some of our Java APIs have been difficult to call from Java, and I think that if we're making an API designed to be called from Java, it might make sense to have a test case for it in Java. Here I understand it's not super important, so just a suggestion.
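A usage sketch of the two new overloads (`df` here is a hypothetical DataFrame with an integer column "a"):

    import scala.collection.JavaConverters._

    // Any Scala Iterable works with the first overload; the second takes a
    // java.lang.Iterable, produced here via JavaConverters for illustration.
    val primes = Set(2, 3, 5)
    df.filter($"a".isInCollection(primes))                 // Scala overload
    df.filter($"a".isInCollection(primes.toList.asJava))   // Java overload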


   /**
    * SQL like expression. Returns a boolean column based on a SQL LIKE match.
    *
@@ -17,6 +17,10 @@

 package org.apache.spark.sql

+import java.util.Locale
+
+import scala.collection.JavaConverters._
+
 import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
 import org.scalatest.Matchers._
@@ -392,9 +396,97 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {

     val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")

-    intercept[AnalysisException] {
+    val e = intercept[AnalysisException] {
       df2.filter($"a".isin($"b"))
     }
+    Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
+      .foreach { s =>
+        assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
+      }
   }

test("isInCollection: Scala Collection") {
Contributor: Can we simplify the test cases? You are just testing this API as a wrapper; you don't need to run so many queries for type coercion.

Member (Author): Done.

val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
checkAnswer(df.filter($"a".isInCollection(Seq(1, 2))),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq(3, 2))),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

// Auto casting should work with mixture of different types in collections
checkAnswer(df.filter($"a".isInCollection(Seq(1.toShort, "2"))),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq("3", 2.toLong))),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq(3, "1"))),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

checkAnswer(df.filter($"b".isInCollection(Seq("y", "x"))),
df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".isInCollection(Seq("z", "x"))),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".isInCollection(Seq("z", "y"))),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))

// Test with different types of collections
checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")

val e = intercept[AnalysisException] {
df2.filter($"a".isInCollection(Seq($"b")))
}
Member: Let's check the error message to prevent future regressions, such as raising a different AnalysisException.

Member (Author): Addressed.

Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
.foreach { s =>
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
}
}

test("isInCollection: Java Collection") {
Contributor: As stated above, maybe this would make sense to do in Java, but your call.

Member (Author): I totally agree that we should have tests natively in Java instead of converting the types to Java in Scala and hoping it will work in Java. Let's do it in the follow-up PR.

val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
Contributor: Same thing here: just run a single test case.

Member (Author): Done.

checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

// Auto casting should work with mixture of different types in collections
checkAnswer(df.filter($"a".isInCollection(Seq(1.toShort, "2").asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq("3", 2.toLong).asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq(3, "1").asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

checkAnswer(df.filter($"b".isInCollection(Seq("y", "x").asJava)),
df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".isInCollection(Seq("z", "x").asJava)),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".isInCollection(Seq("z", "y").asJava)),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))

// Test with different types of collections
checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")

val e = intercept[AnalysisException] {
df2.filter($"a".isInCollection(Seq($"b").asJava))
}
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
.foreach { s =>
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
}
}

test("&&") {
Expand Down