[SPARK-32376][SQL] Make unionByName null-filling behavior work with struct columns #29587

Closed · wants to merge 25 commits · showing changes from 2 commits
@@ -17,29 +17,97 @@

package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, Literal, NamedExpression, WithFields}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils

/**
* Resolves different children of Union to a common set of columns.
*/
object ResolveUnion extends Rule[LogicalPlan] {
-private def unionTwoSides(
/**
* Adds missing fields recursively into the given `col` expression, based on the target
* `StructType`. For example, given `col` as "a struct<a:int, b:int>, b int" and `target` as
* "a struct<a:int, b:int, c: long>, b int, c string", this method should add `a.c` and `c`
* to the `col` expression.
*/
[Member] Could you describe the negative cases? For instance, could you mention compareAndAddFields briefly and what happens when we meet the same column names with different types?
[Author] Revised the comment.
private def addFields(col: NamedExpression, target: StructType): Option[Expression] = {
require(col.dataType.isInstanceOf[StructType], "Only support StructType.")
[Member] assert instead?

val resolver = SQLConf.get.resolver
val missingFields =
StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver)
[Contributor] The name is a bit misleading and I thought it was a Seq. How about missingFieldsOpt?
[Author] Good. Fixed.
if (missingFields.length == 0) {
None
} else {
Some(addFieldsInto(col, "", missingFields.fields))
}
}
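
For orientation, here is a minimal end-to-end sketch of the behavior this rule enables. The session setup, column names, and values are illustrative, and it assumes the allowMissingColumns flag on Dataset.unionByName:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
// Nested struct columns with different fields on each side.
val df1 = spark.sql("SELECT named_struct('a', 1, 'b', CAST(2 AS BIGINT)) AS s, 0 AS id")
val df2 = spark.sql("SELECT named_struct('a', 3, 'c', 'x') AS s, 1 AS id")

// Missing struct fields are null-filled on each side and sorted, so both
// branches produce the same type: s: struct<a:int, b:bigint, c:string>.
df1.unionByName(df2, allowMissingColumns = true).printSchema()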

private def addFieldsInto(col: Expression, base: String, fields: Seq[StructField]): Expression = {
var currCol = col
[Member] To remove the var here, could we use fields.foldLeft(col) { case (currCol, f) => instead?
[Author] Looks good. Rewritten. Thanks.
fields.foreach { field =>
field.dataType match {
case dt: AtomicType =>
// We need to sort the columns in the result because we might add another column on the
// other side. E.g., we want to union two structs "a int, b long" and "a int, c string".
// If we don't sort, we get "a int, b long, c string" on one side and
// "a int, c string, b long" on the other, which are not compatible types
// (see the sketch after this method).
currCol = WithFields(currCol, s"$base${field.name}", Literal(null, dt),
sortColumns = true)
case st: StructType =>
val resolver = SQLConf.get.resolver
val colField = currCol.dataType.asInstanceOf[StructType]
.find(f => resolver(f.name, field.name))
if (colField.isEmpty) {
// The whole struct is missing. Add a null.
currCol = WithFields(currCol, s"$base${field.name}", Literal(null, st),
sortColumns = true)
} else {
currCol = addFieldsInto(currCol, s"$base${field.name}.", st.fields)
}
}
}
currCol
}
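
To make the sorting rationale concrete, a small self-contained check using the struct types from the comment above (plain == on StructType is order-sensitive):

import org.apache.spark.sql.types.StructType

// After null-filling, each side must end up with an identically ordered
// struct type, or Union will reject the branches as incompatible.
val sortedMerge = StructType.fromDDL("a INT, b BIGINT, c STRING") // sorted
val appendedMerge = StructType.fromDDL("a INT, c STRING, b BIGINT") // append order

assert(sortedMerge != appendedMerge) // same fields, different order: not equal
assert(sortedMerge == StructType.fromDDL("a INT, b BIGINT, c STRING"))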

[Member] Although we have a rich comment in the function body, could you add a function description to give a general idea?
[Author] Added.

private def compareAndAddFields(
left: LogicalPlan,
right: LogicalPlan,
-allowMissingCol: Boolean): LogicalPlan = {
+allowMissingCol: Boolean): (Seq[NamedExpression], Seq[NamedExpression]) = {
val resolver = SQLConf.get.resolver
val leftOutputAttrs = left.output
val rightOutputAttrs = right.output

// Builds a project list for `right` based on `left` output names
val aliased = mutable.ArrayBuffer.empty[Attribute]

val rightProjectList = leftOutputAttrs.map { lattr =>
-rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse {
+val found = rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }
if (found.isDefined) {
val foundDt = found.get.dataType
(foundDt, lattr.dataType) match {
case (source: StructType, target: StructType)
if allowMissingCol && !source.sameType(target) =>
[Member] To make the logic simpler, could we filter out all the unsupported cases (e.g., nested struct in array) here? Like this:

          case (source: StructType, target: StructType)
              if allowMissingCol && canMergeSchemas(source, target) =>

[Author] Hmm, I'm not sure where we can simplify the logic. By adding canMergeSchemas, wouldn't it look more complicated?
[maropu, Sep 11, 2020] I read the comment and at first thought that all the unsupported cases are handled in lines 108-112. But does it also mean the cases are unsupported when addFields returns None? This might be an issue that can be fixed just by improving the comments.
[Author] I see. I will add more comments explaining this.
// There is an output with the same name but a different struct type;
// we need to add the missing fields.
addFields(found.get, target).map { added =>
aliased += found.get
Alias(added, found.get.name)()
}.getOrElse(found.get) // No missing fields here; the other side needs fields added instead.
case _ =>
// Same struct type, or
// unsupported: different types, array or map types, or
[Member] TODO work?
[Author] Array and map types aren't supported by WithFields. I think it is still possible to add them to WithFields. Once WithFields supports these types, we can add them here too.
[Member] Ah, ok.

// `allowMissingCol` is disabled.
found.get
[Member] found.get -> foundAttr
}
} else {
if (allowMissingCol) {
Alias(Literal(null, lattr.dataType), lattr.name)()
} else {
Expand All @@ -50,18 +118,29 @@ object ResolveUnion extends Rule[LogicalPlan] {
}
}

(rightProjectList, aliased)
}

private def unionTwoSides(
left: LogicalPlan,
right: LogicalPlan,
allowMissingCol: Boolean): LogicalPlan = {
val rightOutputAttrs = right.output

// Builds a project list for `right` based on `left` output names
val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol)

// Delegates failure checks to `CheckAnalysis`
-val notFoundAttrs = rightOutputAttrs.diff(rightProjectList)
+val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased)
val rightChild = Project(rightProjectList ++ notFoundAttrs, right)

// Builds a project on `left` based on `right` output names, if missing
// columns are allowed.
val leftChild = if (allowMissingCol) {
-val missingAttrs = notFoundAttrs.map { attr =>
-Alias(Literal(null, attr.dataType), attr.name)()
-}
-if (missingAttrs.nonEmpty) {
-Project(leftOutputAttrs ++ missingAttrs, left)
+// Add missing (nested) fields to left plan.
+val (leftProjectList, _) = compareAndAddFields(rightChild, left, allowMissingCol)
+if (leftProjectList.map(_.toAttribute) != left.output) {
[Member] nit:

      if (leftProjectList.length != left.output.length ||
          leftProjectList.map(_.toAttribute) != left.output) {

?
[Author] Doesn't leftProjectList.map(_.toAttribute) != left.output already cover leftProjectList.length != left.output.length?
+Project(leftProjectList, left)
} else {
left
}
@@ -18,10 +18,11 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -546,7 +547,8 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression)
case class WithFields(
structExpr: Expression,
names: Seq[String],
-valExprs: Seq[Expression]) extends Unevaluable {
+valExprs: Seq[Expression],
+sortColumns: Boolean = false) extends Unevaluable {
[Member] How about defining toString for not displaying this value in explain output?
[Member] nit: sortColumns -> sortOutputColumns?
[Author] Changed sortColumns to sortOutputColumns. I'm not sure we want to hide sortColumns?
[maropu, Sep 3, 2020] I thought this param is not related to the withField operation, so I left the comment. But either is okay (just a suggestion).

scala> val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
scala> df.select($"struct_col".withField("c", lit(3))).explain(true)
== Analyzed Logical Plan ==
with_fields(struct_col, 3): struct<a:int,b:int,c:int>
Project [with_fields(struct_col#0, c, 3, false) AS with_fields(struct_col, 3)#4]
                                         ^^^^^
+- Project [named_struct(a, 1, b, 2) AS struct_col#0]
   +- OneRowRelation

[Author] I'm not certain if we want to show it or not. Let's keep it as is and see what others think.

assert(names.length == valExprs.length)

Expand Down Expand Up @@ -585,13 +587,71 @@ case class WithFields(
} else {
resultExprs :+ newExpr
}
-}.flatMap { case (name, expr) => Seq(Literal(name), expr) }
+}

-val expr = CreateNamedStruct(newExprs)
+val finalExprs = if (sortColumns) {
+newExprs.sortBy(_._1).flatMap { case (name, expr) => Seq(Literal(name), expr) }
+} else {
+newExprs.flatMap { case (name, expr) => Seq(Literal(name), expr) }
+}
+
+val expr = CreateNamedStruct(finalExprs)
if (structExpr.nullable) {
If(IsNull(structExpr), Literal(null, expr.dataType), expr)
} else {
expr
}
}
}
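
A hedged sketch of what the new flag changes in practice, using the public Column.withField API (the session setup is illustrative; sortColumns = true is only ever set internally, by ResolveUnion):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = spark.sql("SELECT named_struct('b', 1, 'a', 2) AS s")

// Default path (sortColumns = false): the new field is appended,
// yielding struct<b:int, a:int, c:int>.
df.select($"s".withField("c", lit(3))).printSchema()

// When ResolveUnion inserts missing fields it passes sortColumns = true,
// so the same rewrite would instead yield struct<a:int, b:int, c:int>.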

object WithFields {
/**
* Adds/replaces field in `StructType` into `col` expression by name.
*/
def apply(col: Expression, fieldName: String, expr: Expression): Expression = {
WithFields(col, fieldName, expr, false)
[Member] nit: false -> sortOutputColumns = false?
}

def apply(
col: Expression,
fieldName: String,
expr: Expression,
sortColumns: Boolean): Expression = {
val nameParts = if (fieldName.isEmpty) {
fieldName :: Nil
} else {
CatalystSqlParser.parseMultipartIdentifier(fieldName)
}
withFieldHelper(col, nameParts, Nil, expr, sortColumns)
}

private def withFieldHelper(
struct: Expression,
namePartsRemaining: Seq[String],
namePartsDone: Seq[String],
value: Expression,
sortColumns: Boolean) : WithFields = {
val name = namePartsRemaining.head
if (namePartsRemaining.length == 1) {
WithFields(struct, name :: Nil, value :: Nil, sortColumns)
} else {
val newNamesRemaining = namePartsRemaining.tail
val newNamesDone = namePartsDone :+ name

val newStruct = if (struct.resolved) {
val resolver = SQLConf.get.resolver
ExtractValue(struct, Literal(name), resolver)
} else {
UnresolvedExtractValue(struct, Literal(name))
}

val newValue = withFieldHelper(
struct = newStruct,
namePartsRemaining = newNamesRemaining,
namePartsDone = newNamesDone,
value = value,
sortColumns = sortColumns)
WithFields(struct, name :: Nil, newValue :: Nil, sortColumns)
}
}
}
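
The recursion in withFieldHelper has the same shape as a copy-on-write update of nested maps. A self-contained plain-Scala sketch (not Catalyst; names hypothetical): peel one name off the path, recurse into the sub-structure, and wrap the rewritten child back into its parent.

def setField(struct: Map[String, Any], path: List[String], value: Any): Map[String, Any] =
  path match {
    case Nil => struct // nothing to do for an empty path
    case name :: Nil => struct + (name -> value) // leaf: add or replace the field
    case name :: rest =>
      // Rebuild this level around the updated child, like the outer WithFields
      // wrapping the rewritten inner struct.
      val child = struct.getOrElse(name, Map.empty[String, Any]).asInstanceOf[Map[String, Any]]
      struct + (name -> setField(child, rest, value))
  }

// setField(Map("a" -> Map("b" -> 1)), List("a", "c"), 2)
//   => Map("a" -> Map("b" -> 1, "c" -> 2))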
@@ -39,7 +39,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
// Remove redundant field extraction.
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
createNamedStruct.valExprs(ordinal)
-case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
+case GetStructField(w @ WithFields(struct, names, valExprs, _), ordinal, maybeName) =>
val name = w.dataType(ordinal).name
val matches = names.zip(valExprs).filter(_._1 == name)
if (matches.nonEmpty) {
@@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
*/
object CombineWithFields extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
+case WithFields(WithFields(struct, names1, valExprs1, sort1), names2, valExprs2, sort2)
+if sort1 == sort2 =>
+// Keep the shared sort flag; omitting it here would silently reset it to false.
+WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2, sort1)
}
}
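
For illustration, a hedged before/after of this rule. The base struct and literals are made up, and it assumes the sort flag is threaded through as in the fix noted above:

import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Literal, WithFields}

val base = CreateNamedStruct(Seq(Literal("a"), Literal(1)))
val stacked = WithFields(
  WithFields(base, Seq("b"), Seq(Literal(2)), sortColumns = true),
  Seq("c"), Seq(Literal(3)), sortColumns = true)

// CombineWithFields collapses the two adjacent nodes into one:
//   WithFields(base, Seq("b", "c"), Seq(Literal(2), Literal(3)), sortColumns = true)
// Nodes whose sort flags differ are deliberately left alone.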
@@ -641,4 +641,30 @@ object StructType extends AbstractDataType {
fields.foreach(s => map.put(s.name, s))
map
}

/**
* Returns a `StructType` that contains missing fields recursively from `source` to `target`.
* Note that this doesn't support looking into array type and map type recursively.
*/
[maropu, Aug 31, 2020] Where does this limitation come from? Do we not need to support this case, or is supporting it technically difficult? Ah, I see. This is an unsupported case, right?
https://github.com/apache/spark/pull/29587/files#diff-4d656d696512d6bcb03a48f7e0af6251R106-R107
[Author] I leverage WithFields to add missing nested fields into structs. WithFields doesn't support array or map types currently.
def findMissingFields(source: StructType, target: StructType, resolver: Resolver): StructType = {
[Member] Do we need to define this method on the StructType side instead of as a private method in ResolveUnion?
[Author] I feel this is a more general method related to StructType, so I put it in StructType.
[Member] Okay. One nit: isn't a return type of Option[StructType] more natural for findXXX methods, just like the Scala collections (e.g., Seq.find)?
[Author] OK, that sounds good.
def bothStructType(dt1: DataType, dt2: DataType): Boolean =
dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType]

val newFields = mutable.ArrayBuffer.empty[StructField]

target.fields.foreach { field =>
val found = source.fields.find(f => resolver(field.name, f.name))
if (found.isEmpty) {
// Found a missing field in `source`.
newFields += field
} else if (bothStructType(found.get.dataType, field.dataType) &&
!found.get.dataType.sameType(field.dataType)) {
// Found a field with same name, but different data type.
newFields += found.get.copy(dataType =
findMissingFields(found.get.dataType.asInstanceOf[StructType],
field.dataType.asInstanceOf[StructType], resolver))
}
}
StructType(newFields)
}
}
@@ -18,6 +18,7 @@
package org.apache.spark.sql.types

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType.fromDDL

class StructTypeSuite extends SparkFunSuite {
@@ -103,4 +104,30 @@
val interval = "`a` INTERVAL"
assert(fromDDL(interval).toDDL === interval)
}

test("find missing (nested) fields") {
[Member] Could you test the behaviours of unsupported cases (array and map), too?
[Author] OK, sure.
[Member] If you don't mind, could you split the test function?
[Author] OK.
val schema = StructType.fromDDL(
"c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
val resolver = SQLConf.get.resolver
[Member] Could you add some tests for case-sensitivity?
[Author] Yes, added.

val source1 = StructType.fromDDL("c1 INT")
val missing1 = StructType.fromDDL(
"c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
assert(StructType.findMissingFields(source1, schema, resolver).sameType(missing1))

val source2 = StructType.fromDDL("c1 INT, c3 STRING")
val missing2 = StructType.fromDDL(
"c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
assert(StructType.findMissingFields(source2, schema, resolver).sameType(missing2))

val source3 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT>")
val missing3 = StructType.fromDDL(
"c2 STRUCT<c4: STRUCT<c5: INT, c6: INT>>")
assert(StructType.findMissingFields(source3, schema, resolver).sameType(missing3))

val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c6: INT>>")
val missing4 = StructType.fromDDL(
"c2 STRUCT<c4: STRUCT<c5: INT>>")
[Member] Shall we make this a one-liner?

  val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c6: INT>>")
- val missing4 = StructType.fromDDL(
-   "c2 STRUCT<c4: STRUCT<c5: INT>>")
+ val missing4 = StructType.fromDDL("c2 STRUCT<c4: STRUCT<c5: INT>>")

[Author] Oops, sure.
assert(StructType.findMissingFields(source4, schema, resolver).sameType(missing4))
}
}
27 changes: 1 addition & 26 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -909,32 +909,7 @@
require(fieldName != null, "fieldName cannot be null")
require(col != null, "col cannot be null")

-val nameParts = if (fieldName.isEmpty) {
-fieldName :: Nil
-} else {
-CatalystSqlParser.parseMultipartIdentifier(fieldName)
-}
-withFieldHelper(expr, nameParts, Nil, col.expr)
-}
-
-private def withFieldHelper(
-struct: Expression,
-namePartsRemaining: Seq[String],
-namePartsDone: Seq[String],
-value: Expression) : WithFields = {
-val name = namePartsRemaining.head
-if (namePartsRemaining.length == 1) {
-WithFields(struct, name :: Nil, value :: Nil)
-} else {
-val newNamesRemaining = namePartsRemaining.tail
-val newNamesDone = namePartsDone :+ name
-val newValue = withFieldHelper(
-struct = UnresolvedExtractValue(struct, Literal(name)),
-namePartsRemaining = newNamesRemaining,
-namePartsDone = newNamesDone,
-value = value)
-WithFields(struct, name :: Nil, newValue :: Nil)
-}
+WithFields(expr, fieldName, col.expr)
}
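
A hedged usage sketch of the API whose plumbing moved into the WithFields companion above (session setup and field names are illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = spark.sql("SELECT named_struct('a', 1, 'b', named_struct('c', 2)) AS s")

// "b.d" is parsed as a multipart identifier and expanded into nested
// WithFields by the helper shown earlier, yielding
// struct<a:int, b:struct<c:int, d:int>>.
df.select($"s".withField("b.d", lit(3))).printSchema()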

/**