Skip to content

Commit 1b84e44

Browse files
sadikovi authored and HyukjinKwon committed
[SPARK-40470][SQL] Handle GetArrayStructFields and GetMapValue in "arrays_zip" function
### What changes were proposed in this pull request? This is a follow-up for #37833. The PR fixes column names in `arrays_zip` function for the cases when `GetArrayStructFields` and `GetMapValue` expressions are used (see unit tests for more details). Before the patch, the column names would be indexes or an AnalysisException would be thrown in the case of `GetArrayStructFields` example. ### Why are the changes needed? Fixes an inconsistency issue in Spark 3.2 and onwards where the fields would be labeled as indexes instead of column names. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? I added unit tests that reproduce the issue and confirmed that the patch fixes them. Closes #37911 from sadikovi/SPARK-40470. Authored-by: Ivan Sadikov <ivan.sadikov@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org> (cherry picked from commit 9b0f979) Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 8068bd3 commit 1b84e44

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -191,7 +191,9 @@ case class ArraysZip(children: Seq[Expression], names: Seq[Expression])
191191
case (u: UnresolvedAttribute, _) => Literal(u.nameParts.last)
192192
case (e: NamedExpression, _) if e.resolved => Literal(e.name)
193193
case (e: NamedExpression, _) => NamePlaceholder
194-
case (e: GetStructField, _) => Literal(e.extractFieldName)
194+
case (g: GetStructField, _) => Literal(g.extractFieldName)
195+
case (g: GetArrayStructFields, _) => Literal(g.field.name)
196+
case (g: GetMapValue, _) => Literal(g.key)
195197
case (_, idx) => Literal(idx.toString)
196198
})
197199
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -641,6 +641,51 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
641641
assert(fieldNames.toSeq === Seq("arr_1", "arr_2", "arr_3"))
642642
}
643643

644+
test("SPARK-40470: array_zip should return field names in GetArrayStructFields") {
645+
val df = spark.read.json(Seq(
646+
"""
647+
{
648+
"arr": [
649+
{
650+
"obj": {
651+
"nested": {
652+
"field1": [1],
653+
"field2": [2]
654+
}
655+
}
656+
}
657+
]
658+
}
659+
""").toDS())
660+
661+
val res = df
662+
.selectExpr("arrays_zip(arr.obj.nested.field1, arr.obj.nested.field2) as arr")
663+
.select(col("arr.field1"), col("arr.field2"))
664+
665+
val fieldNames = res.schema.fieldNames
666+
assert(fieldNames.toSeq === Seq("field1", "field2"))
667+
668+
checkAnswer(res, Row(Seq(Seq(1)), Seq(Seq(2))) :: Nil)
669+
}
670+
671+
test("SPARK-40470: arrays_zip should return field names in GetMapValue") {
672+
val df = spark.sql("""
673+
select
674+
map(
675+
'arr_1', array(1, 2),
676+
'arr_2', array(3, 4)
677+
) as map_obj
678+
""")
679+
680+
val res = df.selectExpr("arrays_zip(map_obj.arr_1, map_obj.arr_2) as arr")
681+
682+
val fieldNames = res.schema.head.dataType.asInstanceOf[ArrayType]
683+
.elementType.asInstanceOf[StructType].fieldNames
684+
assert(fieldNames.toSeq === Seq("arr_1", "arr_2"))
685+
686+
checkAnswer(res, Row(Seq(Row(1, 3), Row(2, 4))))
687+
}
688+
644689
def testSizeOfMap(sizeOfNull: Any): Unit = {
645690
val df = Seq(
646691
(Map[Int, Int](1 -> 1, 2 -> 2), "x"),

0 commit comments

Comments (0)