diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index fc04ed0d8b4c8..d59d13d49cef4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1123,7 +1123,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val b = AttributeReference("b", IntegerType, nullable = true)() val array = CreateArray(a :: b :: Nil) assert(!ElementAt(array, Literal(1)).nullable) + assert(!ElementAt(array, Literal(-2)).nullable) assert(ElementAt(array, Literal(2)).nullable) + assert(ElementAt(array, Literal(-1)).nullable) assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable) assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable) @@ -1141,18 +1143,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val inputArray1ContainsNull = c.nullable val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull) assert(!ElementAt(stArray1, Literal(1)).nullable) + assert(!ElementAt(stArray1, Literal(-1)).nullable) val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull) assert(ElementAt(stArray2, Literal(1)).nullable) + assert(ElementAt(stArray2, Literal(-1)).nullable) val d = AttributeReference("d", structType, nullable = true)() val inputArray2 = CreateArray(c :: d :: Nil) val inputArray2ContainsNull = c.nullable || d.nullable val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull) assert(!ElementAt(stArray3, Literal(1)).nullable) + assert(!ElementAt(stArray3, Literal(-2)).nullable) assert(ElementAt(stArray3, Literal(2)).nullable) + assert(ElementAt(stArray3, Literal(-1)).nullable) val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull) assert(ElementAt(stArray4, Literal(1)).nullable) + assert(ElementAt(stArray4, Literal(-2)).nullable) assert(ElementAt(stArray4, Literal(2)).nullable) + assert(ElementAt(stArray4, Literal(-1)).nullable) // GetArrayStructFields case invalid indices assert(!ElementAt(stArray3, Literal(0)).nullable)