Skip to content

Commit

Permalink
[SPARK-36673][SQL] Fix incorrect schema of nested types of union
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This patch fixes the incorrect schema produced by `union`.

### Why are the changes needed?

The schema that `union` currently reports for nested struct columns is incorrect. By the definition of the `union` API, columns should be resolved by position, not by name. Right now, when determining the `output` (a.k.a. the schema) of a union plan, we use the `merge` API, which merges two structs by field name (think of it as concatenating the fields of the two structs when they do not overlap). This merging behavior doesn't match the `union` definition.

As a result, we currently get an incorrect schema even though the query result is correct. We should fix the incorrect schema.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes a bug that produced an incorrect schema.

### How was this patch tested?

Added unit test.

Closes #34025 from viirya/SPARK-36673.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
viirya authored and cloud-fan committed Sep 17, 2021
1 parent 651904a commit cdd7ae9
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 36 deletions.
Expand Up @@ -307,7 +307,7 @@ case class Union(
children.map(_.output).transpose.map { attrs =>
val firstAttr = attrs.head
val nullable = attrs.exists(_.nullable)
val newDt = attrs.map(_.dataType).reduce(StructType.merge)
val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
if (firstAttr.dataType == newDt) {
firstAttr.withNullability(nullable)
} else {
Expand Down
Expand Up @@ -559,52 +559,81 @@ object StructType extends AbstractDataType {
case _ => dt
}

/**
 * Merges two data types the way the UNION operator needs them merged.
 *
 * Everything except struct handling is delegated to `mergeInternal`. Structs are
 * combined field by field in positional order: each resulting field is a copy of
 * the left field whose data type is the recursive union-like merge of both sides
 * and whose nullability is the logical OR of both sides.
 *
 * Fails a `require` check when the two structs have a different number of fields.
 */
private[sql] def unionLikeMerge(left: DataType, right: DataType): DataType =
  mergeInternal(left, right, (s1: StructType, s2: StructType) => {
    // Positional pairing only makes sense when both structs have the same arity.
    require(s1.fields.size == s2.fields.size, "To merge nullability, " +
      "two structs must have same number of fields.")

    val mergedFields = s1.fields.zip(s2.fields).map { case (lhs, rhs) =>
      // Copying the left field keeps every attribute of it other than type and nullability.
      lhs.copy(
        dataType = unionLikeMerge(lhs.dataType, rhs.dataType),
        nullable = lhs.nullable || rhs.nullable)
    }
    StructType(mergedFields.toSeq)
  })

/**
 * Merges two data types, combining struct fields by name.
 *
 * Everything except struct handling is delegated to `mergeInternal`. For structs,
 * the result lists the left fields first, in their original order: a left field
 * that has a same-named field on the right is copied with the recursively merged
 * data type and OR-ed nullability, and any other left field is kept as is. Right
 * fields whose name does not occur on the left are appended afterwards.
 *
 * A non-fatal failure while merging a matched pair of fields is rethrown through
 * `QueryExecutionErrors.failedMergingFieldsError`.
 */
private[sql] def merge(left: DataType, right: DataType): DataType =
  mergeInternal(left, right, (s1: StructType, s2: StructType) => {
    val rightByName = fieldsMap(s2.fields)

    // Left fields, each merged with its same-named right counterpart when one exists.
    val fromLeft = s1.fields.map {
      case leftField @ StructField(leftName, leftType, leftNullable, _) =>
        rightByName.get(leftName) match {
          case Some(StructField(rightName, rightType, rightNullable, _)) =>
            try {
              leftField.copy(
                dataType = merge(leftType, rightType),
                nullable = leftNullable || rightNullable)
            } catch {
              case NonFatal(e) =>
                throw QueryExecutionErrors.failedMergingFieldsError(leftName, rightName, e)
            }
          case None =>
            leftField
        }
    }

    // Right fields with no same-named field on the left are carried over unchanged.
    val leftByName = fieldsMap(s1.fields)
    val onlyOnRight = s2.fields.filter(f => leftByName.get(f.name).isEmpty)

    StructType((fromLeft ++ onlyOnRight).toSeq)
  })

private def mergeInternal(
left: DataType,
right: DataType,
mergeStruct: (StructType, StructType) => StructType): DataType =
(left, right) match {
case (ArrayType(leftElementType, leftContainsNull),
ArrayType(rightElementType, rightContainsNull)) =>
ArrayType(
merge(leftElementType, rightElementType),
mergeInternal(leftElementType, rightElementType, mergeStruct),
leftContainsNull || rightContainsNull)

case (MapType(leftKeyType, leftValueType, leftContainsNull),
MapType(rightKeyType, rightValueType, rightContainsNull)) =>
MapType(
merge(leftKeyType, rightKeyType),
merge(leftValueType, rightValueType),
mergeInternal(leftKeyType, rightKeyType, mergeStruct),
mergeInternal(leftValueType, rightValueType, mergeStruct),
leftContainsNull || rightContainsNull)

case (StructType(leftFields), StructType(rightFields)) =>
val newFields = mutable.ArrayBuffer.empty[StructField]

val rightMapped = fieldsMap(rightFields)
leftFields.foreach {
case leftField @ StructField(leftName, leftType, leftNullable, _) =>
rightMapped.get(leftName)
.map { case rightField @ StructField(rightName, rightType, rightNullable, _) =>
try {
leftField.copy(
dataType = merge(leftType, rightType),
nullable = leftNullable || rightNullable)
} catch {
case NonFatal(e) =>
throw QueryExecutionErrors.failedMergingFieldsError(leftName, rightName, e)
}
}
.orElse {
Some(leftField)
}
.foreach(newFields += _)
}

val leftMapped = fieldsMap(leftFields)
rightFields
.filterNot(f => leftMapped.get(f.name).nonEmpty)
.foreach { f =>
newFields += f
}

StructType(newFields.toSeq)
case (s1: StructType, s2: StructType) => mergeStruct(s1, s2)

case (DecimalType.Fixed(leftPrecision, leftScale),
DecimalType.Fixed(rightPrecision, rightScale)) =>
Expand Down
Expand Up @@ -684,7 +684,7 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
children.map(_.output).transpose.map { attrs =>
val firstAttr = attrs.head
val nullable = attrs.exists(_.nullable)
val newDt = attrs.map(_.dataType).reduce(StructType.merge)
val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
if (firstAttr.dataType == newDt) {
firstAttr.withNullability(nullable)
} else {
Expand Down
Expand Up @@ -1018,6 +1018,64 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
unionDF = df1.unionByName(df2)
checkAnswer(unionDF, expected)
}

test("SPARK-36673: Only merge nullability for Unions of struct") {
  // The two inputs differ only in the letter case of the nested field name.
  val upperCased = spark.range(2).withColumn("nested", struct(expr("id * 5 AS INNER")))
  val lowerCased = spark.range(2).withColumn("nested", struct(expr("id * 5 AS inner")))

  // Expected schema: the nested field carries the first input's name and stays non-nullable.
  val nestedType = StructType(Seq(StructField("INNER", LongType, nullable = false)))
  val expectedSchema = StructType(Seq(
    StructField("id", LongType, nullable = false),
    StructField("nested", nestedType, nullable = false)))

  val expectedRows = Seq(Row(0, Row(0)), Row(1, Row(5)), Row(0, Row(0)), Row(1, Row(5)))
  val expectedFlattened = Seq(Row(0), Row(5), Row(0), Row(5))

  // Both the positional and the by-name union must report the same schema and data.
  for (df <- Seq(upperCased.union(lowerCased), upperCased.unionByName(lowerCased))) {
    assert(df.schema == expectedSchema)
    assert(df.queryExecution.optimizedPlan.schema == expectedSchema)
    assert(df.queryExecution.executedPlan.schema == expectedSchema)

    checkAnswer(df, expectedRows)
    checkAnswer(df.select("nested.*"), expectedFlattened)
  }
}

test("SPARK-36673: Only merge nullability for unionByName of struct") {
  // The two inputs differ only in the letter case of the nested field name.
  val upperCased = spark.range(2).withColumn("nested", struct(expr("id * 5 AS INNER")))
  val lowerCased = spark.range(2).withColumn("nested", struct(expr("id * 5 AS inner")))

  val unioned = upperCased.unionByName(lowerCased)

  // Expected schema: the nested field carries the first input's name and stays non-nullable.
  val nestedType = StructType(Seq(StructField("INNER", LongType, nullable = false)))
  val expectedSchema = StructType(Seq(
    StructField("id", LongType, nullable = false),
    StructField("nested", nestedType, nullable = false)))

  // The schema must agree at the Dataset, optimized-plan and physical-plan levels.
  val observedSchemas = Seq(
    unioned.schema,
    unioned.queryExecution.optimizedPlan.schema,
    unioned.queryExecution.executedPlan.schema)
  observedSchemas.foreach { observed =>
    assert(observed == expectedSchema)
  }

  checkAnswer(unioned, Seq(Row(0, Row(0)), Row(1, Row(5)), Row(0, Row(0)), Row(1, Row(5))))
  checkAnswer(unioned.select("nested.*"), Seq(Row(0), Row(5), Row(0), Row(5)))
}

test("SPARK-36673: Union of structs with different orders") {
  // Case 1: same shapes, but the nested field names are swapped between the two sides.
  val namesInOrder = spark.range(2).withColumn("nested",
    struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2"))))
  val namesSwapped = spark.range(2).withColumn("nested",
    struct(expr("id * 5 AS inner2"), struct(expr("id * 10 AS inner1"))))

  val swappedNamesError = intercept[AnalysisException] {
    namesInOrder.union(namesSwapped).collect()
  }
  assert(swappedNamesError.message
    .contains("Union can only be performed on tables with the compatible column types"))

  // Case 2: additionally cast one field on each side to string, at different positions.
  val innerCastToString = spark.range(2).withColumn("nested",
    struct(expr("id * 5 AS inner1"), struct(expr("id * 10 AS inner2").cast("string"))))
  val outerCastToString = spark.range(2).withColumn("nested",
    struct(expr("id * 5 AS inner2").cast("string"), struct(expr("id * 10 AS inner1"))))

  val mixedCastError = intercept[AnalysisException] {
    innerCastToString.union(outerCastToString).collect()
  }
  assert(mixedCastError.message
    .contains("Union can only be performed on tables with the compatible column types"))
}
}

case class UnionClass1a(a: Int, b: Long, nested: UnionClass2)
Expand Down

0 comments on commit cdd7ae9

Please sign in to comment.