Make unionByName null-filling behavior work with struct columns.
viirya committed Aug 31, 2020
1 parent b33066f commit 95e8fd4
Showing 8 changed files with 277 additions and 46 deletions.
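
For context, a minimal end-to-end sketch of the behavior this commit enables (schemas and column names below are illustrative, not taken from the patch, and it assumes a Spark build that includes this change):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").appName("unionByName-structs").getOrCreate()

// Left side has column s: struct<a: int, b: int>; right side has s: struct<a: int, c: string>.
val df1 = spark.range(1).select(struct(lit(1).as("a"), lit(2).as("b")).as("s"))
val df2 = spark.range(1).select(struct(lit(3).as("a"), lit("x").as("c")).as("s"))

// With allowMissingColumns = true, missing nested fields are null-filled on each side and the
// struct fields are sorted, so both sides resolve to struct<a: int, b: int, c: string>.
val unioned = df1.unionByName(df2, allowMissingColumns = true)
unioned.printSchema()
unioned.show()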
@@ -17,29 +17,97 @@

package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, Literal, NamedExpression, WithFields}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils

/**
* Resolves different children of Union to a common set of columns.
*/
object ResolveUnion extends Rule[LogicalPlan] {
private def unionTwoSides(
/**
* Adds missing fields recursively to the given `col` expression, based on the target `StructType`.
* For example, given a `col` whose struct type has the fields "a struct<a:int, b:int>, b int" and a
* `target` with the fields "a struct<a:int, b:int, c: long>, b int, c string", this method adds
* `a.c` and `c` to the `col` expression.
*/
private def addFields(col: NamedExpression, target: StructType): Option[Expression] = {
require(col.dataType.isInstanceOf[StructType], "Only support StructType.")

val resolver = SQLConf.get.resolver
val missingFields =
StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver)
if (missingFields.length == 0) {
None
} else {
Some(addFieldsInto(col, "", missingFields.fields))
}
}

private def addFieldsInto(col: Expression, base: String, fields: Seq[StructField]): Expression = {
var currCol = col
fields.foreach { field =>
field.dataType match {
case dt: AtomicType =>
// We need to sort the resulting columns, because a missing column may also be added on the
// other side. E.g., when unioning two structs "a int, b long" and "a int, c string", each side
// gets the other's missing field. Without sorting, we would end up with "a int, b long, c string"
// on one side and "a int, c string, b long" on the other, which are not compatible.
currCol = WithFields(currCol, s"$base${field.name}", Literal(null, dt),
sortColumns = true)
case st: StructType =>
val resolver = SQLConf.get.resolver
val colField = currCol.dataType.asInstanceOf[StructType]
.find(f => resolver(f.name, field.name))
if (colField.isEmpty) {
// The whole struct is missing. Add a null.
currCol = WithFields(currCol, s"$base${field.name}", Literal(null, st),
sortColumns = true)
} else {
currCol = addFieldsInto(currCol, s"$base${field.name}.", st.fields)
}
}
}
currCol
}

private def compareAndAddFields(
left: LogicalPlan,
right: LogicalPlan,
allowMissingCol: Boolean): LogicalPlan = {
allowMissingCol: Boolean): (Seq[NamedExpression], Seq[NamedExpression]) = {
val resolver = SQLConf.get.resolver
val leftOutputAttrs = left.output
val rightOutputAttrs = right.output

// Builds a project list for `right` based on `left` output names
val aliased = mutable.ArrayBuffer.empty[Attribute]

val rightProjectList = leftOutputAttrs.map { lattr =>
rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse {
val found = rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }
if (found.isDefined) {
val foundDt = found.get.dataType
(foundDt, lattr.dataType) match {
case (source: StructType, target: StructType)
if allowMissingCol && !source.sameType(target) =>
// Found a right-side output with the same name but a different struct type,
// so we need to add the missing fields.
addFields(found.get, target).map { added =>
aliased += found.get
Alias(added, found.get.name)()
}.getOrElse(found.get) // Data type doesn't change. Missing fields will be added on the other side.
case _ =>
// Keep the attribute as-is: either the struct types already match, the case is
// unsupported (different types, or array/map types), or
// `allowMissingCol` is disabled.
found.get
}
} else {
if (allowMissingCol) {
Alias(Literal(null, lattr.dataType), lattr.name)()
} else {
@@ -50,21 +118,28 @@ object ResolveUnion extends Rule[LogicalPlan] {
}
}

(rightProjectList, aliased)
}

private def unionTwoSides(
left: LogicalPlan,
right: LogicalPlan,
allowMissingCol: Boolean): LogicalPlan = {
val rightOutputAttrs = right.output

// Builds a project list for `right` based on `left` output names
val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol)

// Delegates failure checks to `CheckAnalysis`
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList)
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased)
val rightChild = Project(rightProjectList ++ notFoundAttrs, right)

// Builds a project for `left` based on `right` output names, if allowing
// missing columns.
val leftChild = if (allowMissingCol) {
val missingAttrs = notFoundAttrs.map { attr =>
Alias(Literal(null, attr.dataType), attr.name)()
}
if (missingAttrs.nonEmpty) {
Project(leftOutputAttrs ++ missingAttrs, left)
} else {
left
}
// Add missing (nested) fields to left plan.
val (leftProjectList, _) = compareAndAddFields(rightChild, left, allowMissingCol)
Project(leftProjectList, left)
} else {
left
}
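
A minimal Catalyst-level sketch of the padding step performed above, using the WithFields helper that this patch extends (the attribute name and schema are invented for illustration, not taken from the commit):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, WithFields}
import org.apache.spark.sql.types._

// The right-side output "s" has type struct<a: int, c: string>, while the left side also has a
// field "b" of type bigint. addFieldsInto pads the right side with a null "b" and sorts the
// fields, so both sides end up as struct<a: int, b: bigint, c: string>.
val rightCol = AttributeReference("s", new StructType().add("a", IntegerType).add("c", StringType))()
val padded = WithFields(rightCol, "b", Literal(null, LongType), sortColumns = true)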
@@ -18,10 +18,11 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -546,7 +547,8 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
case class WithFields(
structExpr: Expression,
names: Seq[String],
valExprs: Seq[Expression]) extends Unevaluable {
valExprs: Seq[Expression],
sortColumns: Boolean = false) extends Unevaluable {

assert(names.length == valExprs.length)

@@ -585,13 +587,71 @@
} else {
resultExprs :+ newExpr
}
}.flatMap { case (name, expr) => Seq(Literal(name), expr) }
}

val expr = CreateNamedStruct(newExprs)
val finalExprs = if (sortColumns) {
newExprs.sortBy(_._1).flatMap { case (name, expr) => Seq(Literal(name), expr) }
} else {
newExprs.flatMap { case (name, expr) => Seq(Literal(name), expr) }
}

val expr = CreateNamedStruct(finalExprs)
if (structExpr.nullable) {
If(IsNull(structExpr), Literal(null, expr.dataType), expr)
} else {
expr
}
}
}

object WithFields {
/**
* Adds or replaces a field in the `StructType` of the `col` expression, by name.
*/
def apply(col: Expression, fieldName: String, expr: Expression): Expression = {
WithFields(col, fieldName, expr, false)
}

def apply(
col: Expression,
fieldName: String,
expr: Expression,
sortColumns: Boolean): Expression = {
val nameParts = if (fieldName.isEmpty) {
fieldName :: Nil
} else {
CatalystSqlParser.parseMultipartIdentifier(fieldName)
}
withFieldHelper(col, nameParts, Nil, expr, sortColumns)
}

private def withFieldHelper(
struct: Expression,
namePartsRemaining: Seq[String],
namePartsDone: Seq[String],
value: Expression,
sortColumns: Boolean) : WithFields = {
val name = namePartsRemaining.head
if (namePartsRemaining.length == 1) {
WithFields(struct, name :: Nil, value :: Nil, sortColumns)
} else {
val newNamesRemaining = namePartsRemaining.tail
val newNamesDone = namePartsDone :+ name

val newStruct = if (struct.resolved) {
val resolver = SQLConf.get.resolver
ExtractValue(struct, Literal(name), resolver)
} else {
UnresolvedExtractValue(struct, Literal(name))
}

val newValue = withFieldHelper(
struct = newStruct,
namePartsRemaining = newNamesRemaining,
namePartsDone = newNamesDone,
value = value,
sortColumns = sortColumns)
WithFields(struct, name :: Nil, newValue :: Nil, sortColumns)
}
}
}
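
A hedged sketch of the multipart-name handling added above: a dotted field name is parsed as a multipart identifier, so the helper descends into the nested struct and rebuilds it (names and types are illustrative):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, WithFields}
import org.apache.spark.sql.types._

// Column "s" has type struct<a: struct<b: int>>. Passing "a.c" makes withFieldHelper extract the
// nested struct "a" (via ExtractValue, since "s" is resolved), add a null field "c" to it, and
// wrap the rewritten struct back into "s" as a nested WithFields expression.
val nestedType = new StructType().add("a", new StructType().add("b", IntegerType))
val s = AttributeReference("s", nestedType)()
val withNestedNull = WithFields(s, "a.c", Literal(null, LongType), sortColumns = true)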
@@ -39,7 +39,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
// Remove redundant field extraction.
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
createNamedStruct.valExprs(ordinal)
case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
case GetStructField(w @ WithFields(struct, names, valExprs, _), ordinal, maybeName) =>
val name = w.dataType(ordinal).name
val matches = names.zip(valExprs).filter(_._1 == name)
if (matches.nonEmpty) {
@@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
*/
object CombineWithFields extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
case WithFields(WithFields(struct, names1, valExprs1, sort1), names2, valExprs2, sort2)
if sort1 == sort2 =>
WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
}
}
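
As a hedged illustration of the guard above (the expressions are invented for this sketch), two adjacent WithFields over the same struct are merged only when their sort flags agree:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, WithFields}
import org.apache.spark.sql.types._

// Two chained rewrites of the same struct, both with sortColumns = false ...
val s = AttributeReference("s", new StructType().add("a", IntegerType))()
val chained = WithFields(
  WithFields(s, Seq("b"), Seq(Literal(1)), false),
  Seq("c"), Seq(Literal(2)), false)
// ... which CombineWithFields rewrites into the single node
// WithFields(s, Seq("b", "c"), Seq(Literal(1), Literal(2))).
// If the two sort flags differed, the rule would leave the expression unchanged.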
@@ -641,4 +641,30 @@ object StructType extends AbstractDataType {
fields.foreach(s => map.put(s.name, s))
map
}

/**
* Returns a `StructType` containing the fields that appear in `target` but are missing from
* `source`, recursing into nested struct fields. Note that this doesn't look into array and
* map types recursively.
*/
def findMissingFields(source: StructType, target: StructType, resolver: Resolver): StructType = {
def bothStructType(dt1: DataType, dt2: DataType): Boolean =
dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType]

val newFields = mutable.ArrayBuffer.empty[StructField]

target.fields.foreach { field =>
val found = source.fields.find(f => resolver(field.name, f.name))
if (found.isEmpty) {
// This target field is missing from `source`.
newFields += field
} else if (bothStructType(found.get.dataType, field.dataType) &&
!found.get.dataType.sameType(field.dataType)) {
// Found a field with same name, but different data type.
newFields += found.get.copy(dataType =
findMissingFields(found.get.dataType.asInstanceOf[StructType],
field.dataType.asInstanceOf[StructType], resolver))
}
}
StructType(newFields)
}
}
@@ -18,6 +18,7 @@
package org.apache.spark.sql.types

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType.fromDDL

class StructTypeSuite extends SparkFunSuite {
@@ -103,4 +104,30 @@ class StructTypeSuite extends SparkFunSuite {
val interval = "`a` INTERVAL"
assert(fromDDL(interval).toDDL === interval)
}

test("find missing (nested) fields") {
val schema = StructType.fromDDL(
"c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
val resolver = SQLConf.get.resolver

val source1 = StructType.fromDDL("c1 INT")
val missing1 = StructType.fromDDL(
"c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
assert(StructType.findMissingFields(source1, schema, resolver).sameType(missing1))

val source2 = StructType.fromDDL("c1 INT, c3 STRING")
val missing2 = StructType.fromDDL(
"c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
assert(StructType.findMissingFields(source2, schema, resolver).sameType(missing2))

val source3 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT>")
val missing3 = StructType.fromDDL(
"c2 STRUCT<c4: STRUCT<c5: INT, c6: INT>>")
assert(StructType.findMissingFields(source3, schema, resolver).sameType(missing3))

val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c6: INT>>")
val missing4 = StructType.fromDDL(
"c2 STRUCT<c4: STRUCT<c5: INT>>")
assert(StructType.findMissingFields(source4, schema, resolver).sameType(missing4))
}
}
27 changes: 1 addition & 26 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -909,32 +909,7 @@ class Column(val expr: Expression) extends Logging {
require(fieldName != null, "fieldName cannot be null")
require(col != null, "col cannot be null")

val nameParts = if (fieldName.isEmpty) {
fieldName :: Nil
} else {
CatalystSqlParser.parseMultipartIdentifier(fieldName)
}
withFieldHelper(expr, nameParts, Nil, col.expr)
}

private def withFieldHelper(
struct: Expression,
namePartsRemaining: Seq[String],
namePartsDone: Seq[String],
value: Expression) : WithFields = {
val name = namePartsRemaining.head
if (namePartsRemaining.length == 1) {
WithFields(struct, name :: Nil, value :: Nil)
} else {
val newNamesRemaining = namePartsRemaining.tail
val newNamesDone = namePartsDone :+ name
val newValue = withFieldHelper(
struct = UnresolvedExtractValue(struct, Literal(name)),
namePartsRemaining = newNamesRemaining,
namePartsDone = newNamesDone,
value = value)
WithFields(struct, name :: Nil, newValue :: Nil)
}
WithFields(expr, fieldName, col.expr)
}

/**
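
For completeness, a hedged sketch of the public API that now routes through the shared helper (column and field names are illustrative; it assumes an existing SparkSession `spark`, as in the sketch near the top of this page):

import org.apache.spark.sql.functions._

// "s" is struct<a: int, b: struct<x: int>>; withField("b.c", ...) descends into the nested struct
// "b" and adds (or overwrites) its field "c", now delegating to the shared WithFields helper.
val df = spark.range(1).select(
  struct(lit(1).as("a"), struct(lit(2).as("x")).as("b")).as("s"))
val updated = df.withColumn("s", col("s").withField("b.c", lit(42)))
updated.printSchema()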
[Diff for the remaining changed file not loaded.]
