[SPARK-32376][SQL] Make unionByName null-filling behavior work with struct columns #29587
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,29 +17,107 @@ | |
|
||
package org.apache.spark.sql.catalyst.analysis | ||
|
||
import scala.collection.mutable | ||
|
||
import org.apache.spark.sql.AnalysisException | ||
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} | ||
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, Literal, NamedExpression, WithFields} | ||
import org.apache.spark.sql.catalyst.optimizer.CombineUnions | ||
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} | ||
import org.apache.spark.sql.catalyst.rules.Rule | ||
import org.apache.spark.sql.internal.SQLConf | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.sql.util.SchemaUtils | ||
|
||
/** | ||
* Resolves different children of Union to a common set of columns. | ||
*/ | ||
object ResolveUnion extends Rule[LogicalPlan] { | ||
private def unionTwoSides( | ||
/** | ||
* Adds missing fields recursively into given `col` expression, based on the target `StructType`. | ||
* For example, given `col` as "a struct<a:int, b:int>, b int" and `target` as | ||
* "a struct<a:int, b:int, c:long>, b int, c string", this method should add `a.c` and `c` to | ||
* `col` expression. | ||
*/ | ||
private def addFields(col: NamedExpression, target: StructType): Option[Expression] = { | ||
assert(col.dataType.isInstanceOf[StructType], "Only support StructType.") | ||
|
||
val resolver = SQLConf.get.resolver | ||
val missingFields = | ||
Review comment: The name is a bit misleading and I thought it's a
Reply: Good. Fixed. |
||
StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver) | ||
if (missingFields.isEmpty) { | ||
None | ||
} else { | ||
missingFields.map(s => addFieldsInto(col, "", s.fields)) | ||
} | ||
} | ||
|
||
/** | ||
* Adds missing fields recursively into given `col` expression. The missing fields are given | ||
* in `fields`. For example, given `col` as "a struct<a:int, b:int>, b int", and `fields` is | ||
* "a struct<c:long>, c string". This method will add a nested `a.c` field and a top-level | ||
Review comment: Do we have a test case for
Reply: There are end-to-end tests for that. I will update this comment with the example. |
||
* `c` field to `col` and fill null values for them. | ||
*/ | ||
private def addFieldsInto(col: Expression, base: String, fields: Seq[StructField]): Expression = { | ||
fields.foldLeft(col) { case (currCol, field) => | ||
field.dataType match { | ||
case st: StructType => | ||
val resolver = SQLConf.get.resolver | ||
val colField = currCol.dataType.asInstanceOf[StructType] | ||
.find(f => resolver(f.name, field.name)) | ||
if (colField.isEmpty) { | ||
// The whole struct is missing. Add a null. | ||
WithFields(currCol, s"$base${field.name}", Literal(null, st), | ||
sortOutputColumns = true) | ||
} else { | ||
addFieldsInto(currCol, s"$base${field.name}.", st.fields) | ||
} | ||
case dt => | ||
// We need to sort columns in result, because we might add another column in other side. | ||
// E.g., we want to union two structs "a int, b long" and "a int, c string". | ||
// If we don't sort, we will have "a int, b long, c string" and "a int, c string, b long", | ||
// which are not compatible. | ||
WithFields(currCol, s"$base${field.name}", Literal(null, dt), | ||
Review comment: Just a question. What happens if there exist 4000 missing columns? Any performance issue with
Reply: Oh, I think it is bad for performance using
However, I think we currently have no better approach to fill missing (nested) fields in structs. We might add a note to the |
||
sortOutputColumns = true) | ||
} | ||
} | ||
} | ||
|
||
private def compareAndAddFields( | ||
Review comment: Although we have a rich comment in the function body, could you add a function description to give a general idea?
Reply: Added. |
||
left: LogicalPlan, | ||
right: LogicalPlan, | ||
allowMissingCol: Boolean): LogicalPlan = { | ||
allowMissingCol: Boolean): (Seq[NamedExpression], Seq[NamedExpression]) = { | ||
val resolver = SQLConf.get.resolver | ||
val leftOutputAttrs = left.output | ||
val rightOutputAttrs = right.output | ||
|
||
// Builds a project list for `right` based on `left` output names | ||
val aliased = mutable.ArrayBuffer.empty[Attribute] | ||
|
||
val rightProjectList = leftOutputAttrs.map { lattr => | ||
rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse { | ||
val found = rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) } | ||
if (found.isDefined) { | ||
dongjoon-hyun marked this conversation as resolved.
|
||
val foundAttr = found.get | ||
val foundDt = foundAttr.dataType | ||
(foundDt, lattr.dataType) match { | ||
case (source: StructType, target: StructType) | ||
if allowMissingCol && !source.sameType(target) => | ||
Review comment: To make the logic simpler, could we filter out all the unsupported cases (e.g.,
Reply: Hmm, I'm not sure where we can simplify the logic? By adding
Review comment: I read the comment and I thought first that all the unsupported cases are handled in the line 108-112. But, it also means unsupported cases if
Reply: I see. I will add more comments explaining this. |
||
// Having an output with same name, but different struct type. | ||
// We need to add missing fields. Note that for deeply nested structs, such as a | |
// struct nested in an array in a struct, adding the missing nested fields is not | |
// supported. In that case, simply use the original attribute. | |
addFields(foundAttr, target).map { added => | ||
aliased += foundAttr | ||
Alias(added, foundAttr.name)() | ||
}.getOrElse(foundAttr) | ||
case _ => | ||
// We don't need/try to add missing fields if: | ||
// 1. The attributes of left and right side are the same struct type | ||
// 2. The attributes are not struct types. They might be primitive types, or array, map | ||
// types. We don't support adding missing fields of nested structs in array or map | ||
// types now. | ||
// 3. `allowMissingCol` is disabled. | ||
found.get | ||
|
||
} | ||
} else { | ||
if (allowMissingCol) { | ||
Alias(Literal(null, lattr.dataType), lattr.name)() | ||
} else { | ||
|
@@ -50,18 +128,29 @@ object ResolveUnion extends Rule[LogicalPlan] { | |
} | ||
} | ||
|
||
(rightProjectList, aliased) | ||
} | ||
viirya marked this conversation as resolved.
|
||
|
||
private def unionTwoSides( | ||
left: LogicalPlan, | ||
right: LogicalPlan, | ||
allowMissingCol: Boolean): LogicalPlan = { | ||
val rightOutputAttrs = right.output | ||
|
||
// Builds a project list for `right` based on `left` output names | ||
val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol) | ||
|
||
// Delegates failure checks to `CheckAnalysis` | ||
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList) | ||
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased) | ||
val rightChild = Project(rightProjectList ++ notFoundAttrs, right) | ||
|
||
// Builds a project for `logicalPlan` based on `right` output names, if allowing | ||
// missing columns. | ||
val leftChild = if (allowMissingCol) { | ||
val missingAttrs = notFoundAttrs.map { attr => | ||
Alias(Literal(null, attr.dataType), attr.name)() | ||
} | ||
if (missingAttrs.nonEmpty) { | ||
Project(leftOutputAttrs ++ missingAttrs, left) | ||
// Add missing (nested) fields to left plan. | ||
val (leftProjectList, _) = compareAndAddFields(rightChild, left, allowMissingCol) | ||
if (leftProjectList.map(_.toAttribute) != left.output) { | ||
Review comment: nit:
?
Reply: Doesn't |
||
Project(leftProjectList, left) | ||
} else { | ||
left | ||
} | ||
|
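The comment in `addFieldsInto` argues that the added fields must be sorted, because each side of the union may gain fields the other side already has in a different position. A minimal sketch of that argument follows. It models a struct as a plain list of name/type pairs instead of Catalyst expressions, and `SortedFillSketch` and `fill` are hypothetical names, not part of the PR.

```scala
object SortedFillSketch {
  // A struct is modelled as an ordered list of (fieldName, typeName) pairs.
  type Struct = Seq[(String, String)]

  // Appends every field of `other` that `self` lacks (the real rule would fill
  // these with null literals); optionally sorts the result by field name.
  def fill(self: Struct, other: Struct, sort: Boolean): Struct = {
    val missing = other.filterNot { case (name, _) => self.exists(_._1 == name) }
    val filled = self ++ missing
    if (sort) filled.sortBy(_._1) else filled
  }
}
```

With `left = Seq("a" -> "int", "b" -> "long")` and `right = Seq("a" -> "int", "c" -> "string")`, filling each side from the other without sorting leaves the two results in different field orders, while sorting by name makes them line up.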
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,10 +18,11 @@ | |
package org.apache.spark.sql.catalyst.expressions | ||
|
||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} | ||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedExtractValue} | ||
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder} | ||
import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser | ||
import org.apache.spark.sql.catalyst.util._ | ||
import org.apache.spark.sql.internal.SQLConf | ||
import org.apache.spark.sql.types._ | ||
|
@@ -546,7 +547,8 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E | |
case class WithFields( | ||
structExpr: Expression, | ||
names: Seq[String], | ||
valExprs: Seq[Expression]) extends Unevaluable { | ||
valExprs: Seq[Expression], | ||
sortOutputColumns: Boolean = false) extends Unevaluable { | ||
viirya marked this conversation as resolved.
|
||
|
||
assert(names.length == valExprs.length) | ||
|
||
|
@@ -585,13 +587,71 @@ case class WithFields( | |
} else { | ||
resultExprs :+ newExpr | ||
} | ||
}.flatMap { case (name, expr) => Seq(Literal(name), expr) } | ||
} | ||
|
||
val expr = CreateNamedStruct(newExprs) | ||
val finalExprs = if (sortOutputColumns) { | ||
newExprs.sortBy(_._1).flatMap { case (name, expr) => Seq(Literal(name), expr) } | ||
} else { | ||
newExprs.flatMap { case (name, expr) => Seq(Literal(name), expr) } | ||
} | ||
|
||
val expr = CreateNamedStruct(finalExprs) | ||
if (structExpr.nullable) { | ||
If(IsNull(structExpr), Literal(null, expr.dataType), expr) | ||
} else { | ||
expr | ||
} | ||
} | ||
} | ||
|
||
object WithFields { | ||
/** | ||
* Adds/replaces field in `StructType` into `col` expression by name. | ||
*/ | ||
def apply(col: Expression, fieldName: String, expr: Expression): Expression = { | ||
WithFields(col, fieldName, expr, false) | ||
Review comment: nit. |
||
} | ||
|
||
def apply( | ||
col: Expression, | ||
fieldName: String, | ||
expr: Expression, | ||
sortOutputColumns: Boolean): Expression = { | ||
val nameParts = if (fieldName.isEmpty) { | ||
fieldName :: Nil | ||
} else { | ||
CatalystSqlParser.parseMultipartIdentifier(fieldName) | ||
} | ||
withFieldHelper(col, nameParts, Nil, expr, sortOutputColumns) | ||
} | ||
|
||
private def withFieldHelper( | ||
struct: Expression, | ||
namePartsRemaining: Seq[String], | ||
namePartsDone: Seq[String], | ||
value: Expression, | ||
sortOutputColumns: Boolean) : WithFields = { | ||
val name = namePartsRemaining.head | ||
if (namePartsRemaining.length == 1) { | ||
WithFields(struct, name :: Nil, value :: Nil, sortOutputColumns) | ||
} else { | ||
val newNamesRemaining = namePartsRemaining.tail | ||
val newNamesDone = namePartsDone :+ name | ||
|
||
val newStruct = if (struct.resolved) { | ||
val resolver = SQLConf.get.resolver | ||
ExtractValue(struct, Literal(name), resolver) | ||
} else { | ||
UnresolvedExtractValue(struct, Literal(name)) | ||
} | ||
|
||
val newValue = withFieldHelper( | ||
struct = newStruct, | ||
namePartsRemaining = newNamesRemaining, | ||
namePartsDone = newNamesDone, | ||
value = value, | ||
sortOutputColumns = sortOutputColumns) | ||
WithFields(struct, name :: Nil, newValue :: Nil, sortOutputColumns) | ||
} | ||
} | ||
} |
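`withFieldHelper` walks a multi-part field name one part at a time: the last part sets the value, and every earlier part rebuilds the enclosing struct around the updated inner struct. The sketch below illustrates that recursion on nested `Map`s rather than Catalyst expressions. `NestedFieldSketch` and `withField` are hypothetical names, and two simplifications are assumed: the path arrives already split into parts, and a missing intermediate struct is created as an empty map, which the code in the diff does not do.

```scala
object NestedFieldSketch {
  // A struct value is modelled as a Map from field name to value;
  // nested structs are nested Maps.
  type Struct = Map[String, Any]

  // Sets the field addressed by `path` to `value`, rebuilding each enclosing struct.
  def withField(struct: Struct, path: List[String], value: Any): Struct = path match {
    case Nil => struct
    case name :: Nil =>
      // Last name part: add or replace the field directly.
      struct.updated(name, value)
    case name :: rest =>
      // Extract the inner struct, update it recursively, then put it back.
      val inner = struct.get(name) match {
        case Some(m: Map[_, _]) => m.asInstanceOf[Struct]
        case _ => Map.empty[String, Any] // assumption: create missing parents
      }
      struct.updated(name, withField(inner, rest, value))
  }
}
```

For example, calling `withField` with the path `List("a", "c")` and the value `null` on a struct whose field `a` is itself a struct adds a null `c` inside `a` and leaves the other fields untouched.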
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -641,4 +641,39 @@ object StructType extends AbstractDataType { | |
fields.foreach(s => map.put(s.name, s)) | ||
map | ||
} | ||
|
||
/** | ||
* Returns a `StructType` that contains missing fields recursively from `source` to `target`. | ||
* Note that this doesn't support looking into array type and map type recursively. | ||
Reply: I leverage |
||
*/ | ||
def findMissingFields( | ||
source: StructType, | ||
target: StructType, | ||
resolver: Resolver): Option[StructType] = { | ||
def bothStructType(dt1: DataType, dt2: DataType): Boolean = | ||
dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType] | ||
|
||
val newFields = mutable.ArrayBuffer.empty[StructField] | ||
|
||
target.fields.foreach { field => | ||
val found = source.fields.find(f => resolver(field.name, f.name)) | ||
if (found.isEmpty) { | ||
// Found a missing field in `source`. | ||
newFields += field | ||
} else if (bothStructType(found.get.dataType, field.dataType) && | ||
!found.get.dataType.sameType(field.dataType)) { | ||
// Found a field with same name, but different data type. | ||
findMissingFields(found.get.dataType.asInstanceOf[StructType], | ||
field.dataType.asInstanceOf[StructType], resolver).map { missingType => | ||
newFields += found.get.copy(dataType = missingType) | ||
} | ||
} | ||
} | ||
|
||
if (newFields.isEmpty) { | ||
None | ||
} else { | ||
Some(StructType(newFields)) | ||
} | ||
} | ||
} |
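The recursion in `findMissingFields` can be hard to follow inside the diff, so here is a minimal, dependency-free sketch of the same idea. The types `DType`, `Field`, and `Struct` are hypothetical stand-ins for Spark's `DataType`, `StructField`, and `StructType`, and a case-insensitive name comparison is assumed in place of the configurable `Resolver`. Arrays and maps are not modelled, in line with the note in the Scaladoc.

```scala
// Simplified stand-ins for Spark's type classes (hypothetical names).
sealed trait DType
case object IntT extends DType
case object LongT extends DType
case object StringT extends DType
final case class Field(name: String, dataType: DType)
final case class Struct(fields: Seq[Field]) extends DType

object MissingFieldsSketch {
  // Returns the fields of `target` that `source` lacks, descending into
  // structs that exist on both sides; None when nothing is missing.
  def findMissing(source: Struct, target: Struct): Option[Struct] = {
    val missing = target.fields.flatMap { tf =>
      source.fields.find(_.name.equalsIgnoreCase(tf.name)) match {
        case None =>
          Some(tf) // the whole field is absent from `source`
        case Some(Field(name, s: Struct)) =>
          tf.dataType match {
            // Same name, both structs: report only the nested missing part.
            case t: Struct => findMissing(s, t).map(m => Field(name, m))
            case _ => None
          }
        case Some(_) =>
          None // same name, non-struct type: nothing to add here
      }
    }
    if (missing.isEmpty) None else Some(Struct(missing))
  }
}
```

Applied to the example from the `addFields` Scaladoc (source `a struct<a:int, b:int>, b int`, target `a struct<a:int, b:int, c:long>, b int, c string`), the sketch reports `a struct<c:long>, c string` as missing.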
Review comment: Could you describe the negative cases? For instance, could you mention `compareAndAddFields` briefly and what happens when we meet the same column names with different types?
Reply: Revised the comment.