Skip to content

Commit

Permalink
[SPARK-40470][SQL] Handle GetArrayStructFields and GetMapValue in "arrays_zip" function
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?

This is a follow-up for #37833.

The PR fixes column names in `arrays_zip` function for the cases when `GetArrayStructFields` and `GetMapValue` expressions are used (see unit tests for more details).

Before the patch, the column names would be indexes, or an AnalysisException would be thrown in the `GetArrayStructFields` case.

### Why are the changes needed?

Fixes an inconsistency issue in Spark 3.2 and onwards where the fields would be labeled as indexes instead of column names.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

I added unit tests that reproduce the issue and confirmed that the patch fixes them.

Closes #37911 from sadikovi/SPARK-40470.

Authored-by: Ivan Sadikov <ivan.sadikov@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit 9b0f979)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
sadikovi authored and HyukjinKwon committed Sep 16, 2022
1 parent 8068bd3 commit 1b84e44
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
Expand Up @@ -191,7 +191,9 @@ case class ArraysZip(children: Seq[Expression], names: Seq[Expression])
// Derive a user-friendly field name for each zipped child expression.
// Fall through to the positional index only when no better name exists.
case (u: UnresolvedAttribute, _) => Literal(u.nameParts.last)
case (e: NamedExpression, _) if e.resolved => Literal(e.name)
case (e: NamedExpression, _) => NamePlaceholder
// SPARK-40470: use the extracted field/key name instead of the positional
// index for struct-field, array-of-struct-field and map-value accessors.
case (g: GetStructField, _) => Literal(g.extractFieldName)
case (g: GetArrayStructFields, _) => Literal(g.field.name)
case (g: GetMapValue, _) => Literal(g.key)
case (_, idx) => Literal(idx.toString)
})
}
Expand Down
Expand Up @@ -641,6 +641,51 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
assert(fieldNames.toSeq === Seq("arr_1", "arr_2", "arr_3"))
}

// Consistency fix: the sibling test (and the function itself) use the name
// "arrays_zip"; this test previously said "array_zip".
test("SPARK-40470: arrays_zip should return field names in GetArrayStructFields") {
  // JSON fixture whose nested fields are reached through an array, so that
  // `arr.obj.nested.field1` resolves via a GetArrayStructFields expression.
  val df = spark.read.json(Seq(
    """
    {
      "arr": [
        {
          "obj": {
            "nested": {
              "field1": [1],
              "field2": [2]
            }
          }
        }
      ]
    }
    """).toDS())

  val res = df
    .selectExpr("arrays_zip(arr.obj.nested.field1, arr.obj.nested.field2) as arr")
    .select(col("arr.field1"), col("arr.field2"))

  // Before SPARK-40470 these were positional indexes (or an AnalysisException).
  val fieldNames = res.schema.fieldNames
  assert(fieldNames.toSeq === Seq("field1", "field2"))

  checkAnswer(res, Row(Seq(Seq(1)), Seq(Seq(2))) :: Nil)
}

test("SPARK-40470: arrays_zip should return field names in GetMapValue") {
  // Map values accessed as `map_obj.arr_1` resolve via GetMapValue, so the
  // zipped struct fields should carry the map keys as their names.
  val mapDf = spark.sql("""
    select
      map(
        'arr_1', array(1, 2),
        'arr_2', array(3, 4)
      ) as map_obj
  """)

  val zipped = mapDf.selectExpr("arrays_zip(map_obj.arr_1, map_obj.arr_2) as arr")

  // Unwrap array<struct<...>> and check the struct field names.
  val elementType = zipped.schema.head.dataType.asInstanceOf[ArrayType].elementType
  val names = elementType.asInstanceOf[StructType].fieldNames.toSeq
  assert(names === Seq("arr_1", "arr_2"))

  checkAnswer(zipped, Row(Seq(Row(1, 3), Row(2, 4))))
}

def testSizeOfMap(sizeOfNull: Any): Unit = {
val df = Seq(
(Map[Int, Int](1 -> 1, 2 -> 2), "x"),
Expand Down

0 comments on commit 1b84e44

Please sign in to comment.