-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-33391][SQL] element_at with CreateArray does not respect one-based index. #30296
Changes from 2 commits
10090a7
c0bf2f2
4dac08d
7d09d8c
fc84cac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1966,7 +1966,20 @@ case class ElementAt(left: Expression, right: Expression) | |
} | ||
|
||
override def nullable: Boolean = left.dataType match { | ||
case _: ArrayType => computeNullabilityFromArray(left, right) | ||
case _: ArrayType => | ||
def specialNormalizeIndex: (Int, Int) => Int = { | ||
(arrayLength: Int, index: Int) => { | ||
if (index < 0) { | ||
arrayLength + index | ||
} else if (index == 0) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But if the passed-in index is 0, it will be changed to -1 and passed to the following code, which will throw an exception, whereas the old behavior is to return a default of true.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am just trying to follow the old behavior. |
||
// make it default TRUE. | ||
arrayLength | ||
} else { | ||
index - 1 | ||
} | ||
} | ||
} | ||
computeNullabilityFromArray(left, right, normalizeIndex = specialNormalizeIndex) | ||
case _: MapType => true | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1122,9 +1122,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
val a = AttributeReference("a", IntegerType, nullable = false)() | ||
val b = AttributeReference("b", IntegerType, nullable = true)() | ||
val array = CreateArray(a :: b :: Nil) | ||
assert(!ElementAt(array, Literal(0)).nullable) | ||
assert(ElementAt(array, Literal(1)).nullable) | ||
assert(!ElementAt(array, Subtract(Literal(2), Literal(2))).nullable) | ||
assert(!ElementAt(array, Literal(1)).nullable) | ||
assert(ElementAt(array, Literal(2)).nullable) | ||
assert(!ElementAt(array, Subtract(Literal(2), Literal(1))).nullable) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's test valid negative ordinals. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
assert(ElementAt(array, AttributeReference("ordinal", IntegerType)()).nullable) | ||
|
||
// GetArrayStructFields case | ||
|
@@ -1135,19 +1135,19 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
val inputArray1 = CreateArray(c :: Nil) | ||
val inputArray1ContainsNull = c.nullable | ||
val stArray1 = GetArrayStructFields(inputArray1, f1, 0, 2, inputArray1ContainsNull) | ||
assert(!ElementAt(stArray1, Literal(0)).nullable) | ||
assert(!ElementAt(stArray1, Literal(1)).nullable) | ||
val stArray2 = GetArrayStructFields(inputArray1, f2, 1, 2, inputArray1ContainsNull) | ||
assert(ElementAt(stArray2, Literal(0)).nullable) | ||
assert(ElementAt(stArray2, Literal(1)).nullable) | ||
|
||
val d = AttributeReference("d", structType, nullable = true)() | ||
val inputArray2 = CreateArray(c :: d :: Nil) | ||
val inputArray2ContainsNull = c.nullable || d.nullable | ||
val stArray3 = GetArrayStructFields(inputArray2, f1, 0, 2, inputArray2ContainsNull) | ||
assert(!ElementAt(stArray3, Literal(0)).nullable) | ||
assert(ElementAt(stArray3, Literal(1)).nullable) | ||
assert(!ElementAt(stArray3, Literal(1)).nullable) | ||
assert(ElementAt(stArray3, Literal(2)).nullable) | ||
val stArray4 = GetArrayStructFields(inputArray2, f2, 1, 2, inputArray2ContainsNull) | ||
assert(ElementAt(stArray4, Literal(0)).nullable) | ||
assert(ElementAt(stArray4, Literal(1)).nullable) | ||
assert(ElementAt(stArray4, Literal(2)).nullable) | ||
} | ||
|
||
test("Concat") { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1401,6 +1401,40 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { | |
assert(e3.message.contains(errorMsg3)) | ||
} | ||
|
||
test("SPARK-33391: element_at with CreateArray") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like overkill to have an end-to-end test for it. How about we just add more tests in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated |
||
// element_at should use one-based index and support negative index. | ||
// valid index for array(1, 2, 3) should be 1,2,3,-1,-2,-3 | ||
var df = OneRowRelation().selectExpr("element_at(array(1, 2, 3), 1)") | ||
assert(!df.schema.head.nullable) | ||
checkAnswer( | ||
df, | ||
Seq(Row(1)) | ||
) | ||
|
||
df = OneRowRelation().selectExpr("element_at(array(1, 2, 3), -1)") | ||
assert(!df.schema.head.nullable) | ||
checkAnswer( | ||
df, | ||
Seq(Row(3)) | ||
) | ||
|
||
df = OneRowRelation().selectExpr("element_at(array(1, 2, 3), 3)") | ||
assert(!df.schema.head.nullable) | ||
checkAnswer( | ||
df, | ||
Seq(Row(3)) | ||
) | ||
|
||
// 0 is not a valid index, return default nullable which is 'TRUE'. | ||
df = OneRowRelation().selectExpr("element_at(array(1, 2, 3), 0)") | ||
assert(df.schema.head.nullable) | ||
|
||
val ex = intercept[ArrayIndexOutOfBoundsException] { | ||
df.collect() | ||
} | ||
assert(ex.getMessage.contains("SQL array indices start at 1")) | ||
} | ||
|
||
test("array_union functions") { | ||
val df1 = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b") | ||
val ans1 = Row(Seq(1, 2, 3, 4)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this can still be negative and fail, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Calling nullable will not throw an exception or fail; if the index is out of bounds, it just returns the default, true.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, if the passed-in index is negative and arrayLength + index is still < 0, it will still fail.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to cover the case where arrayLength + index is still < 0 inside this specialNormalizeIndex?