Skip to content

Commit

Permalink
[SPARK-32376][SQL] Make unionByName null-filling behavior work with s…
Browse files Browse the repository at this point in the history
…truct columns

### What changes were proposed in this pull request?

SPARK-29358 added support for `unionByName` to work when the two datasets didn't necessarily have the same schema, but it does not work with nested columns like structs. This patch adds the support to work with struct columns.

The behavior before this PR:

```scala
scala> val df1 = spark.range(1).selectExpr("id c0", "named_struct('c', id + 1, 'b', id + 2, 'a', id + 3) c1")
scala> val df2 = spark.range(1).selectExpr("id c0", "named_struct('c', id + 1, 'b', id + 2) c1")
scala> df1.unionByName(df2, true).printSchema
org.apache.spark.sql.AnalysisException: Union can only be performed on tables with the compatible column types. struct<c:bigint,b:bigint> <> struct<c:bigint,b:bigint,a:bigint> at the second column of the second table;;
'Union false, false
:- Project [id#0L AS c0#2L, named_struct(c, (id#0L + cast(1 as bigint)), b, (id#0L + cast(2 as bigint)), a, (id#0L + cast(3 as bigint))) AS c1#3]
:  +- Range (0, 1, step=1, splits=Some(12))
+- Project [c0#8L, c1#9]
   +- Project [id#6L AS c0#8L, named_struct(c, (id#6L + cast(1 as bigint)), b, (id#6L + cast(2 as bigint))) AS c1#9]
      +- Range (0, 1, step=1, splits=Some(12))
```

The behavior after this PR:

```scala
scala> df1.unionByName(df2, true).printSchema
root
 |-- c0: long (nullable = false)
 |-- c1: struct (nullable = false)
 |    |-- a: long (nullable = true)
 |    |-- b: long (nullable = false)
 |    |-- c: long (nullable = false)
scala> df1.unionByName(df2, true).show()
+---+-------------+
| c0|           c1|
+---+-------------+
|  0|    {3, 2, 1}|
|  0|{ null, 2, 1}|
+---+-------------+
```

### Why are the changes needed?

The `allowMissingColumns` of `unionByName` is a feature allowing merging different schema from two datasets when unioning them together. Nested column support makes the feature more general and flexible for usage.

### Does this PR introduce _any_ user-facing change?

Yes, after this change users can union two datasets with different schema with different structs.

### How was this patch tested?

Unit tests.

Closes #29587 from viirya/SPARK-32376.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
  • Loading branch information
viirya committed Oct 16, 2020
1 parent ce6180c commit e574fcd
Show file tree
Hide file tree
Showing 8 changed files with 555 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,188 @@

package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.unsafe.types.UTF8String

/**
* Resolves different children of Union to a common set of columns.
*/
object ResolveUnion extends Rule[LogicalPlan] {
private def unionTwoSides(
/**
* This method sorts columns recursively in a struct expression based on column names.
*/
private def sortStructFields(expr: Expression): Expression = {
val existingExprs = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
case (name, i) =>
val fieldExpr = GetStructField(KnownNotNull(expr), i)
if (fieldExpr.dataType.isInstanceOf[StructType]) {
(name, sortStructFields(fieldExpr))
} else {
(name, fieldExpr)
}
}.sortBy(_._1).flatMap(pair => Seq(Literal(pair._1), pair._2))

val newExpr = CreateNamedStruct(existingExprs)
if (expr.nullable) {
If(IsNull(expr), Literal(null, newExpr.dataType), newExpr)
} else {
newExpr
}
}

/**
* Assumes input expressions are field expression of `CreateNamedStruct`. This method
* sorts the expressions based on field names.
*/
private def sortFieldExprs(fieldExprs: Seq[Expression]): Seq[Expression] = {
fieldExprs.grouped(2).map { e =>
Seq(e.head, e.last)
}.toSeq.sortBy { pair =>
assert(pair.head.isInstanceOf[Literal])
pair.head.eval().asInstanceOf[UTF8String].toString
}.flatten
}

/**
* This helper method sorts fields in a `UpdateFields` expression by field name.
*/
private def sortStructFieldsInWithFields(expr: Expression): Expression = expr transformUp {
case u: UpdateFields if u.resolved =>
u.evalExpr match {
case i @ If(IsNull(_), _, CreateNamedStruct(fieldExprs)) =>
val sorted = sortFieldExprs(fieldExprs)
val newStruct = CreateNamedStruct(sorted)
i.copy(trueValue = Literal(null, newStruct.dataType), falseValue = newStruct)
case CreateNamedStruct(fieldExprs) =>
val sorted = sortFieldExprs(fieldExprs)
val newStruct = CreateNamedStruct(sorted)
newStruct
case other =>
throw new IllegalStateException(s"`UpdateFields` has incorrect expression: $other. " +
"Please file a bug report with this error message, stack trace, and the query.")
}
}

def simplifyWithFields(expr: Expression): Expression = {
expr.transformUp {
case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
UpdateFields(struct, fieldOps1 ++ fieldOps2)
}
}

/**
* Adds missing fields recursively into given `col` expression, based on the target `StructType`.
* This is called by `compareAndAddFields` when we find two struct columns with same name but
* different nested fields. This method will find out the missing nested fields from `col` to
* `target` struct and add these missing nested fields. Currently we don't support finding out
* missing nested fields of struct nested in array or struct nested in map.
*/
private def addFields(col: NamedExpression, target: StructType): Expression = {
assert(col.dataType.isInstanceOf[StructType], "Only support StructType.")

val resolver = SQLConf.get.resolver
val missingFieldsOpt =
StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver)

// We need to sort columns in result, because we might add another column in other side.
// E.g., we want to union two structs "a int, b long" and "a int, c string".
// If we don't sort, we will have "a int, b long, c string" and
// "a int, c string, b long", which are not compatible.
if (missingFieldsOpt.isEmpty) {
sortStructFields(col)
} else {
missingFieldsOpt.map { s =>
val struct = addFieldsInto(col, s.fields)
// Combines `WithFields`s to reduce expression tree.
val reducedStruct = simplifyWithFields(struct)
val sorted = sortStructFieldsInWithFields(reducedStruct)
sorted
}.get
}
}

/**
* Adds missing fields recursively into given `col` expression. The missing fields are given
* in `fields`. For example, given `col` as "z struct<z:int, y:int>, x int", and `fields` is
* "z struct<w:long>, w string". This method will add a nested `z.w` field and a top-level
* `w` field to `col` and fill null values for them. Note that because we might also add missing
* fields at other side of Union, we must make sure corresponding attributes at two sides have
* same field order in structs, so when we adding missing fields, we will sort the fields based on
* field names. So the data type of returned expression will be
* "w string, x int, z struct<w:long, y:int, z:int>".
*/
private def addFieldsInto(
col: Expression,
fields: Seq[StructField]): Expression = {
fields.foldLeft(col) { case (currCol, field) =>
field.dataType match {
case st: StructType =>
val resolver = SQLConf.get.resolver
val colField = currCol.dataType.asInstanceOf[StructType]
.find(f => resolver(f.name, field.name))
if (colField.isEmpty) {
// The whole struct is missing. Add a null.
UpdateFields(currCol, field.name, Literal(null, st))
} else {
UpdateFields(currCol, field.name,
addFieldsInto(ExtractValue(currCol, Literal(field.name), resolver), st.fields))
}
case dt =>
UpdateFields(currCol, field.name, Literal(null, dt))
}
}
}

/**
* This method will compare right to left plan's outputs. If there is one struct attribute
* at right side has same name with left side struct attribute, but two structs are not the
* same data type, i.e., some missing (nested) fields at right struct attribute, then this
* method will try to add missing (nested) fields into the right attribute with null values.
*/
private def compareAndAddFields(
left: LogicalPlan,
right: LogicalPlan,
allowMissingCol: Boolean): LogicalPlan = {
allowMissingCol: Boolean): (Seq[NamedExpression], Seq[NamedExpression]) = {
val resolver = SQLConf.get.resolver
val leftOutputAttrs = left.output
val rightOutputAttrs = right.output

// Builds a project list for `right` based on `left` output names
val aliased = mutable.ArrayBuffer.empty[Attribute]

val rightProjectList = leftOutputAttrs.map { lattr =>
rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse {
val found = rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }
if (found.isDefined) {
val foundAttr = found.get
val foundDt = foundAttr.dataType
(foundDt, lattr.dataType) match {
case (source: StructType, target: StructType)
if allowMissingCol && !source.sameType(target) =>
// Having an output with same name, but different struct type.
// We need to add missing fields. Note that if there are deeply nested structs such as
// nested struct of array in struct, we don't support to add missing deeply nested field
// like that. We will sort columns in the struct expression to make sure two sides of
// union have consistent schema.
aliased += foundAttr
Alias(addFields(foundAttr, target), foundAttr.name)()
case _ =>
// We don't need/try to add missing fields if:
// 1. The attributes of left and right side are the same struct type
// 2. The attributes are not struct types. They might be primitive types, or array, map
// types. We don't support adding missing fields of nested structs in array or map
// types now.
// 3. `allowMissingCol` is disabled.
foundAttr
}
} else {
if (allowMissingCol) {
Alias(Literal(null, lattr.dataType), lattr.name)()
} else {
Expand All @@ -50,18 +209,29 @@ object ResolveUnion extends Rule[LogicalPlan] {
}
}

(rightProjectList, aliased.toSeq)
}

private def unionTwoSides(
left: LogicalPlan,
right: LogicalPlan,
allowMissingCol: Boolean): LogicalPlan = {
val rightOutputAttrs = right.output

// Builds a project list for `right` based on `left` output names
val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol)

// Delegates failure checks to `CheckAnalysis`
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList)
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased)
val rightChild = Project(rightProjectList ++ notFoundAttrs, right)

// Builds a project for `logicalPlan` based on `right` output names, if allowing
// missing columns.
val leftChild = if (allowMissingCol) {
val missingAttrs = notFoundAttrs.map { attr =>
Alias(Literal(null, attr.dataType), attr.name)()
}
if (missingAttrs.nonEmpty) {
Project(leftOutputAttrs ++ missingAttrs, left)
// Add missing (nested) fields to left plan.
val (leftProjectList, _) = compareAndAddFields(rightChild, left, allowMissingCol)
if (leftProjectList.map(_.toAttribute) != left.output) {
Project(leftProjectList, left)
} else {
left
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -661,3 +662,52 @@ case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperat
}
}
}

object UpdateFields {
private def nameParts(fieldName: String): Seq[String] = {
require(fieldName != null, "fieldName cannot be null")

if (fieldName.isEmpty) {
fieldName :: Nil
} else {
CatalystSqlParser.parseMultipartIdentifier(fieldName)
}
}

/**
* Adds/replaces field of `StructType` into `col` expression by name.
*/
def apply(col: Expression, fieldName: String, expr: Expression): UpdateFields = {
updateFieldsHelper(col, nameParts(fieldName), name => WithField(name, expr))
}

/**
* Drops fields of `StructType` in `col` expression by name.
*/
def apply(col: Expression, fieldName: String): UpdateFields = {
updateFieldsHelper(col, nameParts(fieldName), name => DropField(name))
}

private def updateFieldsHelper(
structExpr: Expression,
namePartsRemaining: Seq[String],
valueFunc: String => StructFieldsOperation) : UpdateFields = {
val fieldName = namePartsRemaining.head
if (namePartsRemaining.length == 1) {
UpdateFields(structExpr, valueFunc(fieldName) :: Nil)
} else {
val newStruct = if (structExpr.resolved) {
val resolver = SQLConf.get.resolver
ExtractValue(structExpr, Literal(fieldName), resolver)
} else {
UnresolvedExtractValue(structExpr, Literal(fieldName))
}

val newValue = updateFieldsHelper(
structExpr = newStruct,
namePartsRemaining = namePartsRemaining.tail,
valueFunc = valueFunc)
UpdateFields(structExpr, WithField(fieldName, newValue) :: Nil)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
s"$child.${name.getOrElse(fieldName)}"
}

def extractFieldName: String = name.getOrElse(childSchema(ordinal).name)

override def sql: String =
child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}"
child.sql + s".${quoteIdentifier(extractFieldName)}"

protected override def nullSafeEval(input: Any): Any =
input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,4 +641,39 @@ object StructType extends AbstractDataType {
fields.foreach(s => map.put(s.name, s))
map
}

/**
* Returns a `StructType` that contains missing fields recursively from `source` to `target`.
* Note that this doesn't support looking into array type and map type recursively.
*/
def findMissingFields(
source: StructType,
target: StructType,
resolver: Resolver): Option[StructType] = {
def bothStructType(dt1: DataType, dt2: DataType): Boolean =
dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType]

val newFields = mutable.ArrayBuffer.empty[StructField]

target.fields.foreach { field =>
val found = source.fields.find(f => resolver(field.name, f.name))
if (found.isEmpty) {
// Found a missing field in `source`.
newFields += field
} else if (bothStructType(found.get.dataType, field.dataType) &&
!found.get.dataType.sameType(field.dataType)) {
// Found a field with same name, but different data type.
findMissingFields(found.get.dataType.asInstanceOf[StructType],
field.dataType.asInstanceOf[StructType], resolver).map { missingType =>
newFields += found.get.copy(dataType = missingType)
}
}
}

if (newFields.isEmpty) {
None
} else {
Some(StructType(newFields.toSeq))
}
}
}
Loading

0 comments on commit e574fcd

Please sign in to comment.