From 9e0bfc4d7f497861180cc1b0974831e2ec911fd6 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 21 Jul 2015 15:09:18 -0700 Subject: [PATCH 01/25] initial attempt at implementation --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 34 ++++++++++++++++++- .../CollectionFunctionsSuite.scala | 5 +++ .../org/apache/spark/sql/functions.scala | 6 ++++ .../spark/sql/DataFrameFunctionsSuite.scala | 11 ++++++ 5 files changed, 56 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index bc0846646174a..acf830ef37c45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -238,6 +238,7 @@ object FunctionRegistry { // collection functions expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayContains]("array_contains"), // misc functions expression[Crc32]("crc32"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6ccb56578f790..e1a863efb8bce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ /** @@ -115,3 
+115,35 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } + +case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { + override def dataType: DataType = BooleanType + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, AnyDataType) + + override def eval(input: InternalRow): Boolean = { + val arr = left.eval(input) + if (arr == null) { + false + } else { + val element = right.eval(input) + arr.asInstanceOf[Seq[Any]].contains(element) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arrGen = left.gen(ctx) + val elementGen = right.gen(ctx) + s""" + ${arrGen.code} + boolean ${ev.isNull} = false; + if (${arrGen.isNull}) { + ${ev.primitive} = false; + } else { + ${elementGen.code} + ${ev.primitive} = ${arrGen.primitive}.contains(${elementGen.primitive}); + } + """ + } + + override def prettyName: String = "array_contains" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 2c7e85c446ec6..79b85d849af22 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -65,4 +65,9 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, ArrayType(StringType)), null) } + + test("Array contains") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 35958299076ce..55cac4d75d6e3 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2218,6 +2218,12 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr) + def array_contains(columnName: String, element: Any): Column = array_contains(Column(columnName), element) + + def array_contains(column: Column, element: Any): Column = + ArrayContains(column.expr, Literal(element)) + + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 431dcf7382f16..9a4181db71c01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -344,4 +344,15 @@ class DataFrameFunctionsSuite extends QueryTest { Seq(Row(2), Row(0), Row(3)) ) } + + test("array contains function") { + val df = Seq( + (Array[Int](1, 2), "x"), + (Array[Int](), "y") + ).toDF("a", "b") + checkAnswer( + df.select(array_contains("a", 1)), + Seq(Row(true), Row(false)) + ) + } } From 69c46fb2c17aa0cf52954840f2b8e3de5ba33e03 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 21 Jul 2015 16:27:28 -0700 Subject: [PATCH 02/25] added tests and codegen --- .../expressions/collectionOperations.scala | 12 ++++++++++-- .../expressions/CollectionFunctionsSuite.scala | 7 +++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 16 ++++++++++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 
e1a863efb8bce..1e2d2b912316a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -126,7 +126,11 @@ case class ArrayContains(left: Expression, right: Expression) extends BinaryExpr false } else { val element = right.eval(input) - arr.asInstanceOf[Seq[Any]].contains(element) + if (element == null) { + false + } else { + arr.asInstanceOf[Seq[Any]].contains(element) + } } } @@ -140,7 +144,11 @@ case class ArrayContains(left: Expression, right: Expression) extends BinaryExpr ${ev.primitive} = false; } else { ${elementGen.code} - ${ev.primitive} = ${arrGen.primitive}.contains(${elementGen.primitive}); + if (${elementGen.isNull}) { + ${ev.primitive} = false; + } else { + ${ev.primitive} = ${arrGen.primitive}.contains(${elementGen.primitive}); + } } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 79b85d849af22..6d2342cf2f308 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -68,6 +68,13 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Array contains") { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[Int](), ArrayType(IntegerType)) + checkEvaluation(ArrayContains(a0, Literal(1)), true) + checkEvaluation(ArrayContains(a0, Literal(0)), false) + checkEvaluation(ArrayContains(a0, Literal(null)), false) + + checkEvaluation(ArrayContains(a1, Literal(1)), false) + checkEvaluation(ArrayContains(a1, Literal(null)), false) } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9a4181db71c01..ff0ee0295e554 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -354,5 +354,21 @@ class DataFrameFunctionsSuite extends QueryTest { df.select(array_contains("a", 1)), Seq(Row(true), Row(false)) ) + checkAnswer( + df.select(array_contains("a", null)), + Seq(Row(false), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, 1)"), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(null, 1)"), + Seq(Row(false), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, null)"), + Seq(Row(false), Row(false)) + ) } } From 72cb4b1db2308de9c60d67195bbf8eb50c82f24a Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 21 Jul 2015 17:52:55 -0700 Subject: [PATCH 03/25] added checkInputTypes and docs --- .../expressions/collectionOperations.scala | 23 +++++++++++++++---- .../org/apache/spark/sql/functions.scala | 17 +++++++++++--- .../spark/sql/DataFrameFunctionsSuite.scala | 11 +++++---- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 1e2d2b912316a..3c53d2019f2f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -116,20 +116,33 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } -case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { 
+case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression { override def dataType: DataType = BooleanType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, AnyDataType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!left.dataType.isInstanceOf[ArrayType]) { + TypeCheckResult.TypeCheckFailure( + s"type of first input must be an array, not ${left.dataType.simpleString}") + } else if (left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + s"type of value must match array type " + + s"${left.dataType.asInstanceOf[ArrayType].elementType.simpleString}, not "+ + s"${right.dataType.simpleString}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } override def eval(input: InternalRow): Boolean = { val arr = left.eval(input) if (arr == null) { false } else { - val element = right.eval(input) - if (element == null) { + val value = right.eval(input) + if (value == null) { false } else { - arr.asInstanceOf[Seq[Any]].contains(element) + arr.asInstanceOf[Seq[Any]].contains(value) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 55cac4d75d6e3..e39bd10e6ee3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2218,10 +2218,21 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr) - def array_contains(columnName: String, element: Any): Column = array_contains(Column(columnName), element) + /** + * Returns true if the array contain the value + * @group collection_funcs + * @since 1.5.0 + */ + def array_contains(columnName: String, value: Any): Column = + array_contains(Column(columnName), value) - def array_contains(column: Column, element: Any): Column = - ArrayContains(column.expr, Literal(element)) + /** + * Returns 
true if the array contain the value + * @group collection_funcs + * @since 1.5.0 + */ + def array_contains(column: Column, value: Any): Column = + ArrayContains(column.expr, Literal(value)) ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ff0ee0295e554..e4071044fc6a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -348,23 +348,24 @@ class DataFrameFunctionsSuite extends QueryTest { test("array contains function") { val df = Seq( (Array[Int](1, 2), "x"), - (Array[Int](), "y") + (Array[Int](), "y"), + (null, "z") ).toDF("a", "b") checkAnswer( df.select(array_contains("a", 1)), - Seq(Row(true), Row(false)) + Seq(Row(true), Row(false), Row(false)) ) checkAnswer( df.select(array_contains("a", null)), - Seq(Row(false), Row(false)) + Seq(Row(false), Row(false), Row(false)) ) checkAnswer( df.selectExpr("array_contains(a, 1)"), - Seq(Row(true), Row(false)) + Seq(Row(true), Row(false), Row(false)) ) checkAnswer( df.selectExpr("array_contains(null, 1)"), - Seq(Row(false), Row(false)) + Seq(Row(false), Row(false), Row(false)) ) checkAnswer( df.selectExpr("array_contains(a, null)"), From 4b4425be1657edcdd2998d7fe8f331d94dd62f16 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 21 Jul 2015 18:07:58 -0700 Subject: [PATCH 04/25] changed Arrays in tests to Seqs --- .../spark/sql/DataFrameFunctionsSuite.scala | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e4071044fc6a8..37aacbec59e01 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -315,9 +315,9 @@ class DataFrameFunctionsSuite extends QueryTest { test("array size function") { val df = Seq( - (Array[Int](1, 2), "x"), - (Array[Int](), "y"), - (Array[Int](1, 2, 3), "z") + (Seq[Int](1, 2), "x"), + (Seq[Int](), "y"), + (Seq[Int](1, 2, 3), "z") ).toDF("a", "b") checkAnswer( df.select(size($"a")), @@ -347,29 +347,29 @@ class DataFrameFunctionsSuite extends QueryTest { test("array contains function") { val df = Seq( - (Array[Int](1, 2), "x"), - (Array[Int](), "y"), + (Seq[Int](1, 2), "x"), + (Seq[Int](), "y"), (null, "z") ).toDF("a", "b") checkAnswer( df.select(array_contains("a", 1)), Seq(Row(true), Row(false), Row(false)) ) - checkAnswer( - df.select(array_contains("a", null)), - Seq(Row(false), Row(false), Row(false)) - ) + //checkAnswer( + // df.select(array_contains("a", null)), + // Seq(Row(false), Row(false), Row(false)) + //) checkAnswer( df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false), Row(false)) ) - checkAnswer( - df.selectExpr("array_contains(null, 1)"), - Seq(Row(false), Row(false), Row(false)) - ) - checkAnswer( - df.selectExpr("array_contains(a, null)"), - Seq(Row(false), Row(false)) - ) + //checkAnswer( + // df.selectExpr("array_contains(null, 1)"), + // Seq(Row(false), Row(false), Row(false)) + //) + //checkAnswer( + // df.selectExpr("array_contains(a, null)"), + // Seq(Row(false), Row(false)) + //) } } From 9623c6497b7a41ef2254f30a70ab4c6e0b8603e3 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 21 Jul 2015 18:29:39 -0700 Subject: [PATCH 05/25] fixed test --- .../expressions/collectionOperations.scala | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3c53d2019f2f5..2b490abcba8e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -151,18 +151,19 @@ case class ArrayContains(left: Expression, right: Expression) extends BinaryExpr val arrGen = left.gen(ctx) val elementGen = right.gen(ctx) s""" - ${arrGen.code} - boolean ${ev.isNull} = false; - if (${arrGen.isNull}) { - ${ev.primitive} = false; - } else { - ${elementGen.code} - if (${elementGen.isNull}) { + ${arrGen.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = false; + if (${arrGen.isNull}) { ${ev.primitive} = false; } else { - ${ev.primitive} = ${arrGen.primitive}.contains(${elementGen.primitive}); + ${elementGen.code} + if (${elementGen.isNull}) { + ${ev.primitive} = false; + } else { + ${ev.primitive} = ${arrGen.primitive}.contains(${elementGen.primitive}); + } } - } """ } From 33b45aa0f7b1f486fbc21d5c8bba06d8ade7cc94 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 21 Jul 2015 18:54:34 -0700 Subject: [PATCH 06/25] reordered test --- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 37aacbec59e01..eed8568d92b1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -355,15 +355,15 @@ class DataFrameFunctionsSuite extends QueryTest { df.select(array_contains("a", 1)), Seq(Row(true), Row(false), Row(false)) ) - //checkAnswer( - // df.select(array_contains("a", null)), - // Seq(Row(false), Row(false), Row(false)) - //) 
checkAnswer( df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false), Row(false)) ) //checkAnswer( + // df.select(array_contains("a", null)), + // Seq(Row(false), Row(false), Row(false)) + //) + //checkAnswer( // df.selectExpr("array_contains(null, 1)"), // Seq(Row(false), Row(false), Row(false)) //) From 65b562c5a73aba0a42b05abb4b3bc9a08c841878 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 22 Jul 2015 10:25:29 -0700 Subject: [PATCH 07/25] made array_contains nullable false --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2b490abcba8e7..3152b99573e03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -133,6 +133,8 @@ case class ArrayContains(left: Expression, right: Expression) extends BinaryExpr } } + override def nullable = false + override def eval(input: InternalRow): Boolean = { val arr = left.eval(input) if (arr == null) { From e8a20a90dce5f08d41f0fac61bf191bc9ba57dad Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 22 Jul 2015 10:43:12 -0700 Subject: [PATCH 08/25] added python df (broken atm) --- python/pyspark/sql/functions.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a73ecc7d93361..390c990060670 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -34,6 +34,7 @@ __all__ = [ 'array', + 'array_contains', 'approxCountDistinct', 'bin', 'coalesce', @@ -1016,6 +1017,21 @@ def soundex(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(1.5) +def 
array_contains(col, value): + """ + Collection function: returns True if the array contains the given value + :param col: name of column containing array + :param value: value to check for in array + + >>> df = sqlContext.createDataFrame([([1, 2, 3],), ([],)], ['data']) + >>> df.select(array_contains(df.data, 1)).collect() + [Row(array_contains(data, 1)=True), Row(array_contains(data, 1)=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) + + class UserDefinedFunction(object): """ User defined function in Python From 12f8795a168f207805cb88574d40cf5b52fa1c2e Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 22 Jul 2015 11:19:39 -0700 Subject: [PATCH 09/25] fix scala style checks --- .../expressions/collectionOperations.scala | 4 ++-- .../spark/sql/DataFrameFunctionsSuite.scala | 20 ++++++++----------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3152b99573e03..2c882631fed0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -126,14 +126,14 @@ case class ArrayContains(left: Expression, right: Expression) extends BinaryExpr } else if (left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { TypeCheckResult.TypeCheckFailure( s"type of value must match array type " + - s"${left.dataType.asInstanceOf[ArrayType].elementType.simpleString}, not "+ + s"${left.dataType.asInstanceOf[ArrayType].elementType.simpleString}, not " + s"${right.dataType.simpleString}") } else { TypeCheckResult.TypeCheckSuccess } } - override def nullable = false + override def nullable: Boolean = false override def 
eval(input: InternalRow): Boolean = { val arr = left.eval(input) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index eed8568d92b1a..757a4415c9c01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -359,17 +359,13 @@ class DataFrameFunctionsSuite extends QueryTest { df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false), Row(false)) ) - //checkAnswer( - // df.select(array_contains("a", null)), - // Seq(Row(false), Row(false), Row(false)) - //) - //checkAnswer( - // df.selectExpr("array_contains(null, 1)"), - // Seq(Row(false), Row(false), Row(false)) - //) - //checkAnswer( - // df.selectExpr("array_contains(a, null)"), - // Seq(Row(false), Row(false)) - //) + // checkAnswer( + // df.select(array_contains("a", null)), + // Seq(Row(false), Row(false), Row(false)) + // ) + // checkAnswer( + // df.selectExpr("array_contains(a, null)"), + // Seq(Row(false), Row(false)) + // ) } } From 28b4f716ac43e30e61cc3a92a99b9c685880fb3f Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 22 Jul 2015 23:49:12 -0700 Subject: [PATCH 10/25] fixed bug with type conversions and re-added tests --- .../expressions/collectionOperations.scala | 16 ++------------ .../org/apache/spark/sql/functions.scala | 8 ------- .../spark/sql/DataFrameFunctionsSuite.scala | 22 +++++++++++-------- 3 files changed, 15 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2c882631fed0e..41fa58762bb82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -116,22 +116,10 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } -case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression { +case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { override def dataType: DataType = BooleanType - override def checkInputDataTypes(): TypeCheckResult = { - if (!left.dataType.isInstanceOf[ArrayType]) { - TypeCheckResult.TypeCheckFailure( - s"type of first input must be an array, not ${left.dataType.simpleString}") - } else if (left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { - TypeCheckResult.TypeCheckFailure( - s"type of value must match array type " + - s"${left.dataType.asInstanceOf[ArrayType].elementType.simpleString}, not " + - s"${right.dataType.simpleString}") - } else { - TypeCheckResult.TypeCheckSuccess - } - } + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, AnyDataType) override def nullable: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e39bd10e6ee3f..d3bbb57878892 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2218,14 +2218,6 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr) - /** - * Returns true if the array contain the value - * @group collection_funcs - * @since 1.5.0 - */ - def array_contains(columnName: String, value: Any): Column = - array_contains(Column(columnName), value) - /** * Returns true if the array contain the value * @group collection_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 757a4415c9c01..c8aa7351a2966 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -352,20 +352,24 @@ class DataFrameFunctionsSuite extends QueryTest { (null, "z") ).toDF("a", "b") checkAnswer( - df.select(array_contains("a", 1)), + df.select(array_contains(df("a"), 1)), Seq(Row(true), Row(false), Row(false)) ) checkAnswer( df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false), Row(false)) ) - // checkAnswer( - // df.select(array_contains("a", null)), - // Seq(Row(false), Row(false), Row(false)) - // ) - // checkAnswer( - // df.selectExpr("array_contains(a, null)"), - // Seq(Row(false), Row(false)) - // ) + checkAnswer( + df.select(array_contains(df("a"), null)), + Seq(Row(false), Row(false), Row(false)) + ) + checkAnswer( + df.selectExpr("array_contains(a, null)"), + Seq(Row(false), Row(false), Row(false)) + ) + checkAnswer( + df.select(array_contains(lit(null), 1)), + Seq(Row(false), Row(false), Row(false)) + ) } } From 2517a5818d31500ce9bbc1b171b29f173869cd0f Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Wed, 22 Jul 2015 23:50:14 -0700 Subject: [PATCH 11/25] removed unused import --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 41fa58762bb82..6e17a17ec3962 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{ + CodegenFallback, CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ From d262e9dbb6b879069afcc93aea02f26819be779c Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 23 Jul 2015 00:30:03 -0700 Subject: [PATCH 12/25] reworked type checking code and added more tests --- .../sql/catalyst/expressions/collectionOperations.scala | 5 ++++- .../catalyst/expressions/CollectionFunctionsSuite.scala | 3 +++ .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 8 ++++++-- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6e17a17ec3962..e250377bf601f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -120,7 +120,10 @@ case class SortArray(base: Expression, ascendingOrder: Expression) case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { override def dataType: DataType = BooleanType - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, AnyDataType) + override def inputTypes: Seq[AbstractDataType] = left.dataType match { + case n @ ArrayType(element, _) => Seq(n, element) + case n @ NullType => Seq(TypeCollection(ArrayType, NullType), AnyDataType) + } override def nullable: Boolean = false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 6d2342cf2f308..580afe8b61f6c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -69,6 +69,7 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Array contains") { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Int](), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(null), ArrayType(IntegerType)) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) @@ -76,5 +77,7 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ArrayContains(a1, Literal(1)), false) checkEvaluation(ArrayContains(a1, Literal(null)), false) + + checkEvaluation(ArrayContains(a2, Literal(null)), false) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c8aa7351a2966..d3c0633a76361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -348,8 +348,8 @@ class DataFrameFunctionsSuite extends QueryTest { test("array contains function") { val df = Seq( (Seq[Int](1, 2), "x"), - (Seq[Int](), "y"), - (null, "z") + (Seq[Int](), "x"), + (null, "x") ).toDF("a", "b") checkAnswer( df.select(array_contains(df("a"), 1)), @@ -371,5 +371,9 @@ class DataFrameFunctionsSuite extends QueryTest { df.select(array_contains(lit(null), 1)), Seq(Row(false), Row(false), Row(false)) ) + checkAnswer( + df.select(array_contains(array(lit(2), lit(null)), 1)), + Seq(Row(false), Row(false), Row(false)) + ) } } From 
686e02937a9bd9068e21350c681fd609e249fb5b Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 23 Jul 2015 00:32:22 -0700 Subject: [PATCH 13/25] fix scala style --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e250377bf601f..08b7074d5ad66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -117,7 +117,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } -case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { +case class ArrayContains(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + override def dataType: DataType = BooleanType override def inputTypes: Seq[AbstractDataType] = left.dataType match { From 85280272a9635b250a3be5a82665f10b9c2c9b66 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 23 Jul 2015 10:21:24 -0700 Subject: [PATCH 14/25] added more tests --- .../sql/catalyst/expressions/CollectionFunctionsSuite.scala | 2 +- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 580afe8b61f6c..11e89e71fc0e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -69,7 +69,7 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Array contains") { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Int](), ArrayType(IntegerType)) - val a2 = Literal.create(Seq(null), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(null), ArrayType(NullType)) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index d3c0633a76361..a154e032768db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -375,5 +375,9 @@ class DataFrameFunctionsSuite extends QueryTest { df.select(array_contains(array(lit(2), lit(null)), 1)), Seq(Row(false), Row(false), Row(false)) ) + checkAnswer( + df.select(array_contains(array(lit(2), lit(null)), null)), + Seq(Row(false), Row(false), Row(false)) + ) } } From d3ca01383e95fe7309f57517da7e77259454b196 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 23 Jul 2015 12:55:36 -0700 Subject: [PATCH 15/25] Fixed type checking to match hive behavior, then added tests to ensure this --- .../expressions/ExpectsInputTypes.scala | 21 +++++---- .../expressions/collectionOperations.scala | 16 +++++-- .../spark/sql/DataFrameFunctionsSuite.scala | 43 +++++++++++-------- 3 files changed, 51 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 2dcbd4eb15031..80e0e38281adc 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -42,16 +42,21 @@ trait ExpectsInputTypes extends Expression { def inputTypes: Seq[AbstractDataType] override def checkInputDataTypes(): TypeCheckResult = { - val mismatches = children.zip(inputTypes).zipWithIndex.collect { - case ((child, expected), idx) if !expected.acceptsType(child.dataType) => - s"argument ${idx + 1} requires ${expected.simpleString} type, " + - s"however, '${child.prettyString}' is of ${child.dataType.simpleString} type." - } + if (children.size == inputTypes.size) { + val mismatches = children.zip(inputTypes).zipWithIndex.collect { + case ((child, expected), idx) if !expected.acceptsType(child.dataType) => + s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + + s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." + } - if (mismatches.isEmpty) { - TypeCheckResult.TypeCheckSuccess + if (mismatches.isEmpty) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) + } } else { - TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) + TypeCheckResult.TypeCheckFailure( + s"Length of children types was ${children.size}, but expected to be ${inputTypes.size}") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 08b7074d5ad66..af057f21803d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -122,9 +122,19 @@ case class ArrayContains(left: Expression, right: Expression) override def dataType: DataType = BooleanType - 
override def inputTypes: Seq[AbstractDataType] = left.dataType match { - case n @ ArrayType(element, _) => Seq(n, element) - case n @ NullType => Seq(TypeCollection(ArrayType, NullType), AnyDataType) + override def inputTypes: Seq[AbstractDataType] = right.dataType match { + case NullType => Seq() + case _ => left.dataType match { + case n @ ArrayType(element, _) => Seq(n, element) + case _ => Seq() + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + inputTypes.size match { + case 0 => TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") + case _ => super.checkInputDataTypes() + } } override def nullable: Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a154e032768db..2b65a0eb0fcee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -348,36 +348,43 @@ class DataFrameFunctionsSuite extends QueryTest { test("array contains function") { val df = Seq( (Seq[Int](1, 2), "x"), - (Seq[Int](), "x"), - (null, "x") + (Seq[Int](), "x") ).toDF("a", "b") + + // Simple test cases checkAnswer( df.select(array_contains(df("a"), 1)), - Seq(Row(true), Row(false), Row(false)) + Seq(Row(true), Row(false)) ) checkAnswer( df.selectExpr("array_contains(a, 1)"), - Seq(Row(true), Row(false), Row(false)) - ) - checkAnswer( - df.select(array_contains(df("a"), null)), - Seq(Row(false), Row(false), Row(false)) - ) - checkAnswer( - df.selectExpr("array_contains(a, null)"), - Seq(Row(false), Row(false), Row(false)) + Seq(Row(true), Row(false)) ) checkAnswer( - df.select(array_contains(lit(null), 1)), - Seq(Row(false), Row(false), Row(false)) + df.select(array_contains(array(lit(2), lit(null)), 1)), + Seq(Row(false), Row(false)) ) + + // In hive, this errors because null has no type information + 
intercept[AnalysisException] { + df.select(array_contains(df("a"), null)) + } + intercept[AnalysisException] { + df.selectExpr("array_contains(a, null)") + } + intercept[AnalysisException] { + df.selectExpr("array_contains(null, 1)") + } + + // In hive, if either argument has a matching type has a null value, return false, even if + // the first argument array contains a null and the second argument is null checkAnswer( - df.select(array_contains(array(lit(2), lit(null)), 1)), - Seq(Row(false), Row(false), Row(false)) + df.selectExpr("array_contains(array(array(1), null)[1], 1)"), + Seq(Row(false), Row(false)) ) checkAnswer( - df.select(array_contains(array(lit(2), lit(null)), null)), - Seq(Row(false), Row(false), Row(false)) + df.selectExpr("array_contains(array(0, null), array(1, null)[1])"), + Seq(Row(false), Row(false)) ) } } From 46f9789b43b776d880ac744cd21ea16950cc63ce Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Thu, 23 Jul 2015 14:16:00 -0700 Subject: [PATCH 16/25] reverted change --- .../expressions/ExpectsInputTypes.scala | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 80e0e38281adc..abe6457747550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -42,21 +42,16 @@ trait ExpectsInputTypes extends Expression { def inputTypes: Seq[AbstractDataType] override def checkInputDataTypes(): TypeCheckResult = { - if (children.size == inputTypes.size) { - val mismatches = children.zip(inputTypes).zipWithIndex.collect { - case ((child, expected), idx) if !expected.acceptsType(child.dataType) => - s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + 
- s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." - } + val mismatches = children.zip(inputTypes).zipWithIndex.collect { + case ((child, expected), idx) if !expected.acceptsType(child.dataType) => + s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + + s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." + } - if (mismatches.isEmpty) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) - } + if (mismatches.isEmpty) { + TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure( - s"Length of children types was ${children.size}, but expected to be ${inputTypes.size}") + TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) } } } From 308239927c92cc30865db114e7bbefb247cb505d Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Mon, 27 Jul 2015 12:46:15 -0700 Subject: [PATCH 17/25] fixed unit test --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 390c990060670..eeaad34be68e9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1025,7 +1025,7 @@ def array_contains(col, value): :param value: value to check for in array >>> df = sqlContext.createDataFrame([([1, 2, 3],), ([],)], ['data']) - >>> df.select(array_contains(df.data, 1)).collect() + >>> df.select(array_contains(df.data, 1L)).collect() [Row(array_contains(data, 1)=True), Row(array_contains(data, 1)=False)] """ sc = SparkContext._active_spark_context From 4e7dce35013f3e00db17fecceed570e52bff401d Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Mon, 27 Jul 2015 14:17:43 -0700 Subject: [PATCH 18/25] added more docs --- python/pyspark/sql/functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 
eeaad34be68e9..0bae0da09ceca 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1020,7 +1020,8 @@ def array_contains(col, value): """ - Collection function: returns True if the array contains the given value + Collection function: returns True if the array contains the given value. The collection + elements and value must be of the same type. :param col: name of column containing array :param value: value to check for in array From b5ffae81f3a7167c7c92c1955f35a0cacda4c155 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 28 Jul 2015 00:00:53 -0700 Subject: [PATCH 19/25] fixed pyspark test --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0bae0da09ceca..1502eeac37c06 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1027,7 +1027,7 @@ def array_contains(col, value): >>> df = sqlContext.createDataFrame([([1, 2, 3],), ([],)], ['data']) >>> df.select(array_contains(df.data, 1L)).collect() - [Row(array_contains(data, 1)=True), Row(array_contains(data, 1)=False)] + [Row(array_contains(data,1)=True), Row(array_contains(data,1)=False)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) From 7a22debf04ecd1f21bd7b74558029185582c45a3 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 28 Jul 2015 16:20:47 -0700 Subject: [PATCH 20/25] Changed test to use strings instead of long/ints which are different between python 2 and 3 --- python/pyspark/sql/functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1502eeac37c06..68c78663e8fa5 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1025,9 +1025,9 @@ def array_contains(col, value): :param
col: name of column containing array :param value: value to check for in array - >>> df = sqlContext.createDataFrame([([1, 2, 3],), ([],)], ['data']) - >>> df.select(array_contains(df.data, 1L)).collect() - [Row(array_contains(data,1)=True), Row(array_contains(data,1)=False)] + >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(array_contains(df.data, "a")).collect() + [Row(array_contains(data,"a")=True), Row(array_contains(data,"a")=False)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) From ffc0591b00aedfc6d54a33d66b35040eff601c9c Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 28 Jul 2015 23:12:24 -0700 Subject: [PATCH 21/25] fixed unit test --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 68c78663e8fa5..3534b439d8ad1 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1027,7 +1027,7 @@ def array_contains(col, value): >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) >>> df.select(array_contains(df.data, "a")).collect() - [Row(array_contains(data,"a")=True), Row(array_contains(data,"a")=False)] + [Row(array_contains(data,a)=True), Row(array_contains(data,a)=False)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) From 4d5b0ff014ea600e68c1dbe5adaad6dc1daabbb2 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Sun, 2 Aug 2015 22:19:37 -0700 Subject: [PATCH 22/25] added docs and another type check --- python/pyspark/sql/functions.py | 1 + .../expressions/collectionOperations.scala | 17 ++++++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3534b439d8ad1..c8ac0bee2d1f7 100644 --- 
a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1022,6 +1022,7 @@ def array_contains(col, value): """ Collection function: returns True if the array contains the given value. The collection elements and value must be of the same type. + :param col: name of column containing array :param value: value to check for in array diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index af057f21803d5..3733611e7af92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -117,6 +117,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def prettyName: String = "sort_array" } +/** + * Checks if the array (left) has the element (right) + */ case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ExpectsInputTypes { @@ -133,7 +136,11 @@ case class ArrayContains(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = { inputTypes.size match { case 0 => TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") - case _ => super.checkInputDataTypes() + case _ => left.dataType match { + case n @ ArrayType(element, _) => super.checkInputDataTypes() + case _ => TypeCheckResult.TypeCheckFailure( + "Arguments must be an array followed by a value of same type as the array members") + } } } @@ -160,13 +167,9 @@ case class ArrayContains(left: Expression, right: Expression) ${arrGen.code} boolean ${ev.isNull} = false; boolean ${ev.primitive} = false; - if (${arrGen.isNull}) { - ${ev.primitive} = false; - } else { + if (!${arrGen.isNull}) { ${elementGen.code} - if (${elementGen.isNull}) { - ${ev.primitive} = false; - } else { + if 
(!${elementGen.isNull}) { ${ev.primitive} = ${arrGen.primitive}.contains(${elementGen.primitive}); } } From e352cf9647b476c4b03f60aca118d2c3fd9dd113 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Sun, 2 Aug 2015 22:31:48 -0700 Subject: [PATCH 23/25] fixed diff from master --- .../spark/sql/catalyst/expressions/ExpectsInputTypes.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index abe6457747550..2dcbd4eb15031 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -44,8 +44,8 @@ trait ExpectsInputTypes extends Expression { override def checkInputDataTypes(): TypeCheckResult = { val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => - s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + - s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." + s"argument ${idx + 1} requires ${expected.simpleString} type, " + + s"however, '${child.prettyString}' is of ${child.dataType.simpleString} type." 
} if (mismatches.isEmpty) { From bc3d1fef8166b5192b3948f3fc28582625bcc45a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 4 Aug 2015 15:56:07 -0700 Subject: [PATCH 24/25] fix array_contains --- python/pyspark/sql/functions.py | 1 - .../expressions/collectionOperations.scala | 40 +++++++++++-------- .../CollectionFunctionsSuite.scala | 6 +-- .../org/apache/spark/sql/functions.scala | 17 ++++---- 4 files changed, 35 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index c8ac0bee2d1f7..14003dea0d824 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -34,7 +34,6 @@ __all__ = [ 'array', - 'array_contains', 'approxCountDistinct', 'bin', 'coalesce', diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3733611e7af92..83a7e11bfaf54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -121,7 +121,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) * Checks if the array (left) has the element (right) */ case class ArrayContains(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def dataType: DataType = BooleanType @@ -134,13 +134,14 @@ case class ArrayContains(left: Expression, right: Expression) } override def checkInputDataTypes(): TypeCheckResult = { - inputTypes.size match { - case 0 => TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") - case _ => left.dataType match { - case n @ ArrayType(element, _) => super.checkInputDataTypes() - case _ => TypeCheckResult.TypeCheckFailure( - "Arguments must be an 
array followed by a value of same type as the array members") - } + if (right.dataType == NullType) { + TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") + } else if (!left.dataType.isInstanceOf[ArrayType] + || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be an array followed by a value of same type as the array members") + } else { + TypeCheckResult.TypeCheckSuccess } } @@ -155,7 +156,7 @@ case class ArrayContains(left: Expression, right: Expression) if (value == null) { false } else { - arr.asInstanceOf[Seq[Any]].contains(value) + arr.asInstanceOf[ArrayData].toArray[Any](right.dataType).contains(value) } } } @@ -163,16 +164,23 @@ case class ArrayContains(left: Expression, right: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val arrGen = left.gen(ctx) val elementGen = right.gen(ctx) + val i = ctx.freshName("i") + val getValue = ctx.getValue(arrGen.primitive, right.dataType, i) s""" - ${arrGen.code} - boolean ${ev.isNull} = false; - boolean ${ev.primitive} = false; - if (!${arrGen.isNull}) { - ${elementGen.code} - if (!${elementGen.isNull}) { - ${ev.primitive} = ${arrGen.primitive}.contains(${elementGen.primitive}); + ${arrGen.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = false; + if (!${arrGen.isNull}) { + ${elementGen.code} + if (!${elementGen.isNull}) { + for (int $i = 0; $i < ${arrGen.primitive}.numElements(); $i ++) { + if (${ctx.genEqual(right.dataType, elementGen.primitive, getValue)}) { + ${ev.primitive} = true; + break; + } } } + } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 11e89e71fc0e6..95f0e38212a1a 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -68,14 +68,14 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Array contains") { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) - val a1 = Literal.create(Seq[Int](), ArrayType(IntegerType)) - val a2 = Literal.create(Seq(null), ArrayType(NullType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) checkEvaluation(ArrayContains(a0, Literal(null)), false) - checkEvaluation(ArrayContains(a1, Literal(1)), false) + checkEvaluation(ArrayContains(a1, Literal("")), true) checkEvaluation(ArrayContains(a1, Literal(null)), false) checkEvaluation(ArrayContains(a2, Literal(null)), false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d3bbb57878892..edee60e049cfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2184,6 +2184,14 @@ object functions { // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Returns true if the array contain the value + * @group collection_funcs + * @since 1.5.0 + */ + def array_contains(column: Column, value: Any): Column = + ArrayContains(column.expr, Literal(value)) + /** * Creates a new row for each element in the given array or map column. 
* @@ -2218,15 +2226,6 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr) - /** - * Returns true if the array contain the value - * @group collection_funcs - * @since 1.5.0 - */ - def array_contains(column: Column, value: Any): Column = - ArrayContains(column.expr, Literal(value)) - - ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// From d3c08bc79f9e03afc7cce8ee4adc4ea6c825ba68 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 4 Aug 2015 20:16:21 -0700 Subject: [PATCH 25/25] use foreach() to avoid copy --- .../sql/catalyst/expressions/collectionOperations.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 83a7e11bfaf54..646afa4047d84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -156,7 +156,10 @@ case class ArrayContains(left: Expression, right: Expression) if (value == null) { false } else { - arr.asInstanceOf[ArrayData].toArray[Any](right.dataType).contains(value) + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == value) return true + ) + false } } }