diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5cdd3c7eb62d1..fff072c831909 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1420,7 +1420,9 @@ case class Reverse(child: Expression) group = "array_funcs", since = "1.5.0") case class ArrayContains(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Predicate + extends BinaryExpression + with ImplicitCastInputTypes + with Predicate with QueryErrorsBase { @transient private lazy val ordering: Ordering[Any] = @@ -1472,50 +1474,51 @@ case class ArrayContains(left: Expression, right: Expression) left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull } - override def nullSafeEval(arr: Any, value: Any): Any = { - var hasNull = false - arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == null) { - hasNull = true - } else if (ordering.equiv(v, value)) { + override def eval(input: InternalRow): Any = { + val array = left.eval(input) + val value = right.eval(input) + val arrayData = array.asInstanceOf[ArrayData] + if (arrayData == null) return null + arrayData.foreach(right.dataType, (_, v) => + if (v == null && value == null) { return true + } else if (v != null && value != null) { + if (ordering.equiv(v, value)) { + return true + } } ) - if (hasNull) { - null - } else { - false - } + false } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (arr, value) => { - val i = ctx.freshName("i") - val getValue = CodeGenerator.getValue(arr, right.dataType, i) - val loopBodyCode = if (nullable) { - s""" - |if ($arr.isNullAt($i)) { - | ${ev.isNull} = true; - |} else if (${ctx.genEqual(right.dataType, value, getValue)}) { - | ${ev.isNull} = false; - | ${ev.value} = true; - | break; - |} - """.stripMargin - } else { - s""" - |if (${ctx.genEqual(right.dataType, value, getValue)}) { - | ${ev.value} = true; - | break; - |} - """.stripMargin - } - s""" - |for (int $i = 0; $i < $arr.numElements(); $i ++) { - | $loopBodyCode + val i = ctx.freshName("i") + val arrayEval = left.genCode(ctx) + val valueEval = right.genCode(ctx) + val getValue = CodeGenerator.getValue(arrayEval.value, right.dataType, i) + ev.copy(code = + code""" + |${arrayEval.code} + |${valueEval.code} + |boolean ${ev.isNull} = false; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (${arrayEval.value} == null) { + | ${ev.isNull} = true; + |} else { + | for (int $i = 0; $i < ${arrayEval.value}.numElements(); $i++) { + | if (${arrayEval.value}.isNullAt($i) && ${valueEval.isNull}) { + | ${ev.value} = true; + | break; + | } else if (!${arrayEval.value}.isNullAt($i) && !${valueEval.isNull}) { + | if (${ctx.genEqual(right.dataType, valueEval.value, getValue)}) { + | ${ev.value} = true; + | break; + | } + | } + | } |} """.stripMargin - }) + ) } override def prettyName: String = "array_contains" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 55148978fa005..7e632e7f10aa0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -131,7 +131,7 @@ class CollectionExpressionsSuite val m1 = Literal.create(null, MapType(StringType, StringType)) checkEvaluation(ArrayContains(MapKeys(m0), Literal("a")), true) checkEvaluation(ArrayContains(MapKeys(m0), Literal("c")), false) - checkEvaluation(ArrayContains(MapKeys(m0), Literal(null, StringType)), null) + checkEvaluation(ArrayContains(MapKeys(m0), Literal(null, StringType)), false) checkEvaluation(ArrayContains(MapKeys(m1), Literal("a")), null) } @@ -591,15 +591,15 @@ class CollectionExpressionsSuite checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) - checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) + checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), false) checkEvaluation(ArrayContains(a5, Literal(1)), true) checkEvaluation(ArrayContains(a1, Literal("")), true) - checkEvaluation(ArrayContains(a1, Literal("a")), null) - checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null) + checkEvaluation(ArrayContains(a1, Literal("a")), false) + checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), true) - checkEvaluation(ArrayContains(a2, Literal(1L)), null) - checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null) + checkEvaluation(ArrayContains(a2, Literal(1L)), false) + checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), true) checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) @@ -623,8 +623,8 @@ class CollectionExpressionsSuite checkEvaluation(ArrayContains(b0, be), true) checkEvaluation(ArrayContains(b1, be), false) - checkEvaluation(ArrayContains(b0, nullBinary), null) - checkEvaluation(ArrayContains(b2, be), null) + checkEvaluation(ArrayContains(b0, nullBinary), false) + checkEvaluation(ArrayContains(b2, be), false) checkEvaluation(ArrayContains(b3, be), true) // complex data types diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 016803635ff60..b5fe3e853d463 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1785,6 +1785,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ), queryContext = Array(ExpectedContext("", "", 0, 32, "array_contains('a string', 'foo')")) ) + + val schema = StructType(Seq( + StructField("a", ArrayType(IntegerType, containsNull = true)), + StructField("b", IntegerType))) + val data = Seq(Row(Seq[Integer](1, 2, 3, null), null)) + val df1 = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + checkAnswer(df1.select(array_contains(col("a"), col("b"))), Seq(Row(true))) } test("SPARK-29600: ArrayContains function may return incorrect result for DecimalType") {