Skip to content

Commit

Permalink
[SPARK-24371][SQL] Added isInCollection in DataFrame API for Scala and Java.
Browse files Browse the repository at this point in the history

## What changes were proposed in this pull request?

Implemented **`isInCollection`** in DataFrame API for both Scala and Java, so users can do

```scala
val profileDF = Seq(
  Some(1), Some(2), Some(3), Some(4),
  Some(5), Some(6), Some(7), None
).toDF("profileID")

val validUsers: Seq[Any] = Seq(6, 7.toShort, 8L, "3")

val result = profileDF.withColumn("isValid", $"profileID".isInCollection(validUsers))

result.show(10)
"""
+---------+-------+
|profileID|isValid|
+---------+-------+
|        1|  false|
|        2|  false|
|        3|   true|
|        4|  false|
|        5|  false|
|        6|   true|
|        7|   true|
|     null|   null|
+---------+-------+
 """.stripMargin
```
## How was this patch tested?

Several unit tests are added.

Author: DB Tsai <d_tsai@apple.com>

Closes #21416 from dbtsai/optimize-set.
  • Loading branch information
dbtsai committed May 29, 2018
1 parent aca65c6 commit 900bc1f
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet
import scala.collection.mutable.{ArrayBuffer, Stack}

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ abstract class LogicalPlan
schema.map { field =>
resolve(field.name :: Nil, resolver).map {
case a: AttributeReference => a
case other => sys.error(s"can not handle nested schema yet... plan $this")
case _ => sys.error(s"can not handle nested schema yet... plan $this")
}.getOrElse {
throw new AnalysisException(
s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]")
Expand Down
19 changes: 19 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql

import scala.collection.JavaConverters._
import scala.language.implicitConversions

import org.apache.spark.annotation.InterfaceStability
Expand Down Expand Up @@ -786,6 +787,24 @@ class Column(val expr: Expression) extends Logging {
/**
 * A boolean expression that is evaluated to true if the value of this expression is contained
 * by the evaluated values of the arguments.
 *
 * Each argument is lifted to a literal via `lit` and the whole list is wrapped in a
 * Catalyst `In` predicate over this column's expression.
 */
@scala.annotation.varargs
def isin(list: Any*): Column = withExpr { In(expr, list.map(lit(_).expr)) }

/**
 * A boolean expression that is evaluated to true if the value of this expression is contained
 * by the provided collection.
 *
 * @group expr_ops
 * @since 2.4.0
 */
def isInCollection(values: scala.collection.Iterable[_]): Column = {
  // Materialize the iterable once, then delegate to the varargs `isin` overload.
  val elements = values.toSeq
  isin(elements: _*)
}

/**
 * A boolean expression that is evaluated to true if the value of this expression is contained
 * by the provided collection.
 *
 * @group java_expr_ops
 * @since 2.4.0
 */
def isInCollection(values: java.lang.Iterable[_]): Column = {
  // View the Java iterable as a Scala collection and reuse the Scala overload.
  isInCollection(values.asScala)
}

/**
* SQL like expression. Returns a boolean column based on a SQL LIKE match.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

package org.apache.spark.sql

import java.util.Locale

import scala.collection.JavaConverters._

import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
import org.scalatest.Matchers._
Expand Down Expand Up @@ -390,11 +394,67 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.filter($"b".isin("z", "y")),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))

// Auto casting should work with mixture of different types in collections
checkAnswer(df.filter($"a".isin(1.toShort, "2")),
df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isin("3", 2.toLong)),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
checkAnswer(df.filter($"a".isin(3, "1")),
df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))

val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")

intercept[AnalysisException] {
val e = intercept[AnalysisException] {
df2.filter($"a".isin($"b"))
}
Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
.foreach { s =>
assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
}
}

test("isInCollection: Scala Collection") {
  val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
  // Expected answers are computed client-side from the collected rows.
  val rows = df.collect().toSeq
  def expected(keep: Int*) = rows.filter(r => keep.contains(r.getInt(0)))

  // `isInCollection` must accept different Scala collection shapes: Seq, Set, Array, List.
  checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), expected(3, 1))
  checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), expected(1, 2))
  checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), expected(3, 2))
  checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), expected(3, 1))

  // Passing a Column as an element must be rejected by the analyzer with a
  // resolution / type-mismatch error rather than silently coerced.
  val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
  val analysisError = intercept[AnalysisException] {
    df2.filter($"a".isInCollection(Seq($"b")))
  }
  for (fragment <- Seq(
      "cannot resolve",
      "due to data type mismatch: Arguments must be same type but were")) {
    assert(analysisError.getMessage.toLowerCase(Locale.ROOT)
      .contains(fragment.toLowerCase(Locale.ROOT)))
  }
}

test("isInCollection: Java Collection") {
  val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
  // Expected answers are computed client-side from the collected rows.
  val rows = df.collect().toSeq
  def expected(keep: Int*) = rows.filter(r => keep.contains(r.getInt(0)))

  // `isInCollection` must accept java.lang.Iterable: converted List and Set instances.
  checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), expected(1, 2))
  checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), expected(1, 2))
  checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), expected(3, 1))

  // A Java collection holding a Column must also fail analysis with a
  // resolution / type-mismatch error.
  val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
  val analysisError = intercept[AnalysisException] {
    df2.filter($"a".isInCollection(Seq($"b").asJava))
  }
  for (fragment <- Seq(
      "cannot resolve",
      "due to data type mismatch: Arguments must be same type but were")) {
    assert(analysisError.getMessage.toLowerCase(Locale.ROOT)
      .contains(fragment.toLowerCase(Locale.ROOT)))
  }
}

test("&&") {
Expand Down

0 comments on commit 900bc1f

Please sign in to comment.