Skip to content

Commit

Permalink
Addressed feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
dbtsai committed May 26, 2018
1 parent 6ff2806 commit 286a468
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,10 @@ object OptimizeIn extends Rule[LogicalPlan] {
// When v is not nullable, the following expression will be optimized
// to FalseLiteral which is tested in OptimizeInSuite.scala
If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType))
case In(v, Seq(elem @ Literal(_, _))) => EqualTo(v, elem)
case In(v, Seq(elem @ Literal(_, _))) =>
// An `Expression` such as `ListQuery` contains a subquery, which cannot
// be converted into `EqualTo`. Only a `Literal` is converted, for safety.
EqualTo(v, elem)
case expr @ In(v, list) if expr.inSetConvertible =>
val newList = ExpressionSet(list).toSeq
if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) {
Expand Down
8 changes: 4 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -789,21 +789,21 @@ class Column(val expr: Expression) extends Logging {

/**
* A boolean expression that is evaluated to true if the value of this expression is contained
* by the provided Set.
* by the provided collection.
*
* @group expr_ops
* @since 2.4.0
*/
def isinSet(values: scala.collection.Set[_]): Column = isin(values.toSeq: _*)
def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*)

/**
* A boolean expression that is evaluated to true if the value of this expression is contained
* by the provided Set.
* by the provided collection.
*
* @group java_expr_ops
* @since 2.4.0
*/
def isinSet(values: java.util.Set[_]): Column = isinSet(values.asScala)
def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala)

/**
* SQL like expression. Returns a boolean column based on a SQL LIKE match.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,69 +405,83 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
}
}

test("isinSet: Scala Set") {
test("isInCollection: Scala Collection") {
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
checkAnswer(df.filter($"a".isinSet(Set(1, 2))),
checkAnswer(df.filter($"a".isInCollection(Seq(1, 2))),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set(3, 2))),
checkAnswer(df.filter($"a".isInCollection(Seq(3, 2))),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set(3, 1))),
checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

// Auto casting should work with mixture of different types in Set
checkAnswer(df.filter($"a".isinSet(Set(1.toShort, "2"))),
// Auto casting should work with a mixture of different types in collections
checkAnswer(df.filter($"a".isInCollection(Seq(1.toShort, "2"))),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set("3", 2.toLong))),
checkAnswer(df.filter($"a".isInCollection(Seq("3", 2.toLong))),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set(3, "1"))),
checkAnswer(df.filter($"a".isInCollection(Seq(3, "1"))),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

checkAnswer(df.filter($"b".isinSet(Set("y", "x"))),
checkAnswer(df.filter($"b".isInCollection(Seq("y", "x"))),
df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".isinSet(Set("z", "x"))),
checkAnswer(df.filter($"b".isInCollection(Seq("z", "x"))),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".isinSet(Set("z", "y"))),
checkAnswer(df.filter($"b".isInCollection(Seq("z", "y"))),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))

// Test with different types of collections
checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")

val e = intercept[AnalysisException] {
df2.filter($"a".isinSet(Set($"b")))
df2.filter($"a".isInCollection(Seq($"b")))
}
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
.foreach { s =>
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
}
}

test("isinSet: Java Set") {
test("isInCollection: Java Collection") {
val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
checkAnswer(df.filter($"a".isinSet(Set(1, 2).asJava)),
checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set(3, 2).asJava)),
checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set(3, 1).asJava)),
checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

// Auto casting should work with mixture of different types in Set
checkAnswer(df.filter($"a".isinSet(Set(1.toShort, "2").asJava)),
// Auto casting should work with a mixture of different types in collections
checkAnswer(df.filter($"a".isInCollection(Seq(1.toShort, "2").asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set("3", 2.toLong).asJava)),
checkAnswer(df.filter($"a".isInCollection(Seq("3", 2.toLong).asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isinSet(Set(3, "1").asJava)),
checkAnswer(df.filter($"a".isInCollection(Seq(3, "1").asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

checkAnswer(df.filter($"b".isinSet(Set("y", "x").asJava)),
checkAnswer(df.filter($"b".isInCollection(Seq("y", "x").asJava)),
df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".isinSet(Set("z", "x").asJava)),
checkAnswer(df.filter($"b".isInCollection(Seq("z", "x").asJava)),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".isinSet(Set("z", "y").asJava)),
checkAnswer(df.filter($"b".isInCollection(Seq("z", "y").asJava)),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))

// Test with different types of collections
checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")

val e = intercept[AnalysisException] {
df2.filter($"a".isinSet(Set($"b").asJava))
df2.filter($"a".isInCollection(Seq($"b").asJava))
}
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
.foreach { s =>
Expand Down

0 comments on commit 286a468

Please sign in to comment.