From 95e8fd44396a0843bd0be4722bb2c1724f084a91 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@gmail.com>
Date: Sun, 30 Aug 2020 19:26:32 -0700
Subject: [PATCH] Make unionByName null-filling behavior work with struct
 columns.

---
 .../sql/catalyst/analysis/ResolveUnion.scala  | 103 +++++++++++++++---
 .../expressions/complexTypeCreator.scala      |  68 +++++++++++-
 .../sql/catalyst/optimizer/ComplexTypes.scala |   2 +-
 .../sql/catalyst/optimizer/WithFields.scala   |   3 +-
 .../apache/spark/sql/types/StructType.scala   |  26 +++++
 .../spark/sql/types/StructTypeSuite.scala     |  27 +++++
 .../scala/org/apache/spark/sql/Column.scala   |  27 +----
 .../sql/DataFrameSetOperationsSuite.scala     |  67 ++++++++++++
 8 files changed, 277 insertions(+), 46 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
index 693a5a4e75443..c0eace6052590 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
@@ -17,29 +17,97 @@
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, Literal, NamedExpression, WithFields}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
 
 /**
  * Resolves different children of Union to a common set of columns.
 */
 object ResolveUnion extends Rule[LogicalPlan] {
-  private def unionTwoSides(
+  /**
+   * Adds missing fields recursively into the given `col` expression, based on the target
+   * `StructType`. For example, given `col` as "a struct<a:int, b:long>, b int" and `target` as
+   * "a struct<a:int, b:long, c:string>, b int, c string", this method should add `a.c` and `c`
+   * to the `col` expression.
+   */
+  private def addFields(col: NamedExpression, target: StructType): Option[Expression] = {
+    require(col.dataType.isInstanceOf[StructType], "Only support StructType.")
+
+    val resolver = SQLConf.get.resolver
+    val missingFields =
+      StructType.findMissingFields(col.dataType.asInstanceOf[StructType], target, resolver)
+    if (missingFields.length == 0) {
+      None
+    } else {
+      Some(addFieldsInto(col, "", missingFields.fields))
+    }
+  }
+
+  private def addFieldsInto(col: Expression, base: String, fields: Seq[StructField]): Expression = {
+    var currCol = col
+    fields.foreach { field =>
+      field.dataType match {
+        case dt: AtomicType =>
+          // We need to sort columns in the result, because we might add another column on the
+          // other side. E.g., we want to union two structs "a int, b long" and "a int, c string".
+          // If we don't sort, we will have "a int, b long, c string" and
+          // "a int, c string, b long", which are not compatible.
+          currCol = WithFields(currCol, s"$base${field.name}", Literal(null, dt),
+            sortColumns = true)
+        case st: StructType =>
+          val resolver = SQLConf.get.resolver
+          val colField = currCol.dataType.asInstanceOf[StructType]
+            .find(f => resolver(f.name, field.name))
+          if (colField.isEmpty) {
+            // The whole struct is missing. Add a null.
+            currCol = WithFields(currCol, s"$base${field.name}", Literal(null, st),
+              sortColumns = true)
+          } else {
+            currCol = addFieldsInto(currCol, s"$base${field.name}.", st.fields)
+          }
+      }
+    }
+    currCol
+  }
+
+  private def compareAndAddFields(
       left: LogicalPlan,
       right: LogicalPlan,
-      allowMissingCol: Boolean): LogicalPlan = {
+      allowMissingCol: Boolean): (Seq[NamedExpression], Seq[NamedExpression]) = {
     val resolver = SQLConf.get.resolver
     val leftOutputAttrs = left.output
     val rightOutputAttrs = right.output
 
-    // Builds a project list for `right` based on `left` output names
+    val aliased = mutable.ArrayBuffer.empty[Attribute]
+
     val rightProjectList = leftOutputAttrs.map { lattr =>
-      rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse {
+      val found = rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }
+      if (found.isDefined) {
+        val foundDt = found.get.dataType
+        (foundDt, lattr.dataType) match {
+          case (source: StructType, target: StructType)
+              if allowMissingCol && !source.sameType(target) =>
+            // Found an output with the same name but a different struct type.
+            // We need to add missing fields.
+            addFields(found.get, target).map { added =>
+              aliased += found.get
+              Alias(added, found.get.name)()
+            }.getOrElse(found.get) // Data type doesn't change. We should add fields on the other side.
+          case _ =>
+            // Same struct type, or
+            // unsupported: different types, array or map types, or
+            // `allowMissingCol` is disabled.
+            found.get
+        }
+      } else {
         if (allowMissingCol) {
           Alias(Literal(null, lattr.dataType), lattr.name)()
         } else {
@@ -50,21 +118,28 @@ object ResolveUnion extends Rule[LogicalPlan] {
       }
     }
 
+    (rightProjectList, aliased)
+  }
+
+  private def unionTwoSides(
+      left: LogicalPlan,
+      right: LogicalPlan,
+      allowMissingCol: Boolean): LogicalPlan = {
+    val rightOutputAttrs = right.output
+
+    // Builds a project list for `right` based on `left` output names
+    val (rightProjectList, aliased) = compareAndAddFields(left, right, allowMissingCol)
+
     // Delegates failure checks to `CheckAnalysis`
-    val notFoundAttrs = rightOutputAttrs.diff(rightProjectList)
+    val notFoundAttrs = rightOutputAttrs.diff(rightProjectList ++ aliased)
     val rightChild = Project(rightProjectList ++ notFoundAttrs, right)
 
     // Builds a project for `logicalPlan` based on `right` output names, if allowing
     // missing columns.
     val leftChild = if (allowMissingCol) {
-      val missingAttrs = notFoundAttrs.map { attr =>
-        Alias(Literal(null, attr.dataType), attr.name)()
-      }
-      if (missingAttrs.nonEmpty) {
-        Project(leftOutputAttrs ++ missingAttrs, left)
-      } else {
-        left
-      }
+      // Add missing (nested) fields to the left plan.
+      val (leftProjectList, _) = compareAndAddFields(rightChild, left, allowMissingCol)
+      Project(leftProjectList, left)
     } else {
       left
     }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 563ce7133a3dc..87f229fa83e2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -18,10 +18,11 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -546,7 +547,8 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
 case class WithFields(
     structExpr: Expression,
     names: Seq[String],
-    valExprs: Seq[Expression]) extends Unevaluable {
+    valExprs: Seq[Expression],
+    sortColumns: Boolean = false) extends Unevaluable {
 
   assert(names.length == valExprs.length)
 
@@ -585,9 +587,15 @@ case class WithFields(
       } else {
         resultExprs :+ newExpr
       }
-    }.flatMap { case (name, expr) => Seq(Literal(name), expr) }
+    }
 
-    val expr = CreateNamedStruct(newExprs)
+    val finalExprs = if (sortColumns) {
+      newExprs.sortBy(_._1).flatMap { case (name, expr) => Seq(Literal(name), expr) }
+    } else {
+      newExprs.flatMap { case (name, expr) => Seq(Literal(name), expr) }
+    }
+
+    val expr = CreateNamedStruct(finalExprs)
     if (structExpr.nullable) {
       If(IsNull(structExpr), Literal(null, expr.dataType), expr)
     } else {
@@ -595,3 +603,55 @@ case class WithFields(
     }
   }
 }
+
+object WithFields {
+  /**
+   * Adds/replaces a field by name in the given struct `col` expression.
+   */
+  def apply(col: Expression, fieldName: String, expr: Expression): Expression = {
+    WithFields(col, fieldName, expr, false)
+  }
+
+  def apply(
+      col: Expression,
+      fieldName: String,
+      expr: Expression,
+      sortColumns: Boolean): Expression = {
+    val nameParts = if (fieldName.isEmpty) {
+      fieldName :: Nil
+    } else {
+      CatalystSqlParser.parseMultipartIdentifier(fieldName)
+    }
+    withFieldHelper(col, nameParts, Nil, expr, sortColumns)
+  }
+
+  private def withFieldHelper(
+      struct: Expression,
+      namePartsRemaining: Seq[String],
+      namePartsDone: Seq[String],
+      value: Expression,
+      sortColumns: Boolean) : WithFields = {
+    val name = namePartsRemaining.head
+    if (namePartsRemaining.length == 1) {
+      WithFields(struct, name :: Nil, value :: Nil, sortColumns)
+    } else {
+      val newNamesRemaining = namePartsRemaining.tail
+      val newNamesDone = namePartsDone :+ name
+
+      val newStruct = if (struct.resolved) {
+        val resolver = SQLConf.get.resolver
+        ExtractValue(struct, Literal(name), resolver)
+      } else {
+        UnresolvedExtractValue(struct, Literal(name))
+      }
+
+      val newValue = withFieldHelper(
+        struct = newStruct,
+        namePartsRemaining = newNamesRemaining,
+        namePartsDone = newNamesDone,
+        value = value,
+        sortColumns = sortColumns)
+      WithFields(struct, name :: Nil, newValue :: Nil, sortColumns)
+    }
+  }
+}

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index 1c33a2c7c3136..87005cbde11de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -39,7 +39,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
         createNamedStruct.valExprs(ordinal)
-      case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
+      case GetStructField(w @ WithFields(struct, names, valExprs, _), ordinal, maybeName) =>
         val name = w.dataType(ordinal).name
         val matches = names.zip(valExprs).filter(_._1 == name)
         if (matches.nonEmpty) {

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
index 05c90864e4bb0..44572a9c46d91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
@@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
  */
 object CombineWithFields extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
+    case WithFields(WithFields(struct, names1, valExprs1, sort1), names2, valExprs2, sort2)
+        if sort1 == sort2 =>
       WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
   }
 }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index b14fb04cc4539..fa97fe233e9ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -641,4 +641,30 @@ object StructType extends AbstractDataType {
     fields.foreach(s => map.put(s.name, s))
     map
   }
+
+  /**
+   * Returns a `StructType` that contains the fields recursively missing from `source`,
+   * relative to `target`. Note that this doesn't recursively look into array and map types.
+   */
+  def findMissingFields(source: StructType, target: StructType, resolver: Resolver): StructType = {
+    def bothStructType(dt1: DataType, dt2: DataType): Boolean =
+      dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType]
+
+    val newFields = mutable.ArrayBuffer.empty[StructField]
+
+    target.fields.foreach { field =>
+      val found = source.fields.find(f => resolver(field.name, f.name))
+      if (found.isEmpty) {
+        // Found a missing field in `source`.
+        newFields += field
+      } else if (bothStructType(found.get.dataType, field.dataType) &&
+          !found.get.dataType.sameType(field.dataType)) {
+        // Found a field with the same name but a different data type.
+        newFields += found.get.copy(dataType =
+          findMissingFields(found.get.dataType.asInstanceOf[StructType],
+            field.dataType.asInstanceOf[StructType], resolver))
+      }
+    }
+    StructType(newFields)
+  }
 }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
index 6824a64badc10..3f5bf56662f99 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.types
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType.fromDDL
 
 class StructTypeSuite extends SparkFunSuite {
@@ -103,4 +104,30 @@ class StructTypeSuite extends SparkFunSuite {
     val interval = "`a` INTERVAL"
     assert(fromDDL(interval).toDDL === interval)
   }
+
+  test("find missing (nested) fields") {
+    val schema = StructType.fromDDL(
+      "c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
+    val resolver = SQLConf.get.resolver
+
+    val source1 = StructType.fromDDL("c1 INT")
+    val missing1 = StructType.fromDDL(
+      "c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
+    assert(StructType.findMissingFields(source1, schema, resolver).sameType(missing1))
+
+    val source2 = StructType.fromDDL("c1 INT, c3 STRING")
+    val missing2 = StructType.fromDDL(
+      "c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
+    assert(StructType.findMissingFields(source2, schema, resolver).sameType(missing2))
+
+    val source3 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT>")
+    val missing3 = StructType.fromDDL(
+      "c2 STRUCT<c4: STRUCT<c5: INT, c6: INT>>")
+    assert(StructType.findMissingFields(source3, schema, resolver).sameType(missing3))
+
+    val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c6: INT>>")
+    val missing4 = StructType.fromDDL(
+      "c2 STRUCT<c4: STRUCT<c5: INT>>")
+    assert(StructType.findMissingFields(source4, schema, resolver).sameType(missing4))
+  }
 }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index da542c67d9c51..dabcd905587fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -909,32 +909,7 @@ class Column(val expr: Expression) extends Logging {
     require(fieldName != null, "fieldName cannot be null")
     require(col != null, "col cannot be null")
 
-    val nameParts = if (fieldName.isEmpty) {
-      fieldName :: Nil
-    } else {
-      CatalystSqlParser.parseMultipartIdentifier(fieldName)
-    }
-    withFieldHelper(expr, nameParts, Nil, col.expr)
-  }
-
-  private def withFieldHelper(
-      struct: Expression,
-      namePartsRemaining: Seq[String],
-      namePartsDone: Seq[String],
-      value: Expression) : WithFields = {
-    val name = namePartsRemaining.head
-    if (namePartsRemaining.length == 1) {
-      WithFields(struct, name :: Nil, value :: Nil)
-    } else {
-      val newNamesRemaining = namePartsRemaining.tail
-      val newNamesDone = namePartsDone :+ name
-      val newValue = withFieldHelper(
-        struct = UnresolvedExtractValue(struct, Literal(name)),
-        namePartsRemaining = newNamesRemaining,
-        namePartsDone = newNamesDone,
-        value = value)
-      WithFields(struct, name :: Nil, newValue :: Nil)
-    }
+    WithFields(expr, fieldName, col.expr)
   }
 
   /**

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
index e72b8ce860b28..72b387b7fdff1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala
@@ -536,4 +536,71 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
       assert(union2.schema.fieldNames === Array("a", "B", "C", "c"))
     }
   }
+
+  test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - 1") {
+    val df1 = Seq(((1, 2, 3), 0), ((2, 3, 4), 1), ((3, 4, 5), 2)).toDF("a", "idx")
+    val df2 = Seq(((3, 4), 0), ((1, 2), 1), ((2, 3), 2)).toDF("a", "idx")
+    val df3 = Seq(((100, 101, 102, 103), 0), ((110, 111, 112, 113), 1), ((120, 121, 122, 123), 2))
+      .toDF("a", "idx")
+
+    var unionDf = df1.unionByName(df2, true)
+
+    checkAnswer(unionDf,
+      Row(Row(1, 2, 3), 0) :: Row(Row(2, 3, 4), 1) :: Row(Row(3, 4, 5), 2) :: // df1
+      Row(Row(3, 4, null), 0) :: Row(Row(1, 2, null), 1) :: Row(Row(2, 3, null), 2) :: Nil // df2
+    )
+
+    assert(unionDf.schema.toDDL == "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT>,`idx` INT")
+
+    unionDf = df1.unionByName(df2, true).unionByName(df3, true)
+
+    checkAnswer(unionDf,
+      Row(Row(1, 2, 3, null), 0) ::
+      Row(Row(2, 3, 4, null), 1) ::
+      Row(Row(3, 4, 5, null), 2) :: // df1
+      Row(Row(3, 4, null, null), 0) ::
+      Row(Row(1, 2, null, null), 1) ::
+      Row(Row(2, 3, null, null), 2) :: // df2
+      Row(Row(100, 101, 102, 103), 0) ::
+      Row(Row(110, 111, 112, 113), 1) ::
+      Row(Row(120, 121, 122, 123), 2) :: Nil // df3
+    )
+    assert(unionDf.schema.toDDL ==
+      "`a` STRUCT<`_1`: INT, `_2`: INT, `_3`: INT, `_4`: INT>,`idx` INT")
+  }
+
+  test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - 2") {
+    val df1 = Seq((0, UnionClass1a(0, 1L, UnionClass2(1, "2")))).toDF("id", "a")
+    val df2 = Seq((1, UnionClass1b(1, 2L, UnionClass3(2, 3L)))).toDF("id", "a")
+
+    var unionDf = df1.unionByName(df2, true)
+    checkAnswer(unionDf,
+      Row(0, Row(0, 1, Row(1, null, "2"))) ::
+      Row(1, Row(1, 2, Row(2, 3L, null))) :: Nil)
+    assert(unionDf.schema.toDDL ==
+      "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " +
+      "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>")
+
+    unionDf = df2.unionByName(df1, true)
+    checkAnswer(unionDf,
+      Row(1, Row(1, 2, Row(2, 3L, null))) ::
+      Row(0, Row(0, 1, Row(1, null, "2"))) :: Nil)
+    assert(unionDf.schema.toDDL ==
+      "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " +
+      "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>")
+
+    val df3 = Seq((2, UnionClass1b(2, 3L, null))).toDF("id", "a")
+    unionDf = df1.unionByName(df3, true)
+    checkAnswer(unionDf,
+      Row(0, Row(0, 1, Row(1, null, "2"))) ::
+      Row(2, Row(2, 3, null)) :: Nil)
+    assert(unionDf.schema.toDDL ==
+      "`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " +
+      "`nested`: STRUCT<`a`: INT, `b`: BIGINT, `c`: STRING>>")
+  }
 }
+
+case class UnionClass1a(a: Int, b: Long, nested: UnionClass2)
+case class UnionClass1b(a: Int, b: Long, nested: UnionClass3)
+case class UnionClass2(a: Int, c: String)
+case class UnionClass3(a: Int, b: Long)
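
--
Usage sketch (not part of the patch): a minimal, self-contained illustration of the
behavior this change enables, assuming a Spark build that includes it. It uses only the
`unionByName(other, allowMissingColumns)` overload that the new tests above already call;
the object name `UnionByNameStructExample` is invented for the example.

import org.apache.spark.sql.SparkSession

object UnionByNameStructExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("SPARK-32376 example")
      .getOrCreate()
    import spark.implicits._

    // The struct column `a` has three nested fields on one side and two on the other.
    val df1 = Seq(((1, 2, 3), 0)).toDF("a", "idx") // a: struct<_1, _2, _3>
    val df2 = Seq(((3, 4), 0)).toDF("a", "idx")    // a: struct<_1, _2>

    // Before this patch, unioning mismatched struct types failed to resolve even with
    // allowMissingColumns = true. With it, df2's struct gains a null `_3` and both sides
    // are projected to a common layout; nested fields are sorted by name on both sides,
    // which is why WithFields gained the sortColumns flag.
    val unioned = df1.unionByName(df2, allowMissingColumns = true)
    unioned.printSchema() // a: struct<_1: int, _2: int, _3: int>, idx: int
    unioned.show()        // contains Row(Row(1, 2, 3), 0) and Row(Row(3, 4, null), 0)

    spark.stop()
  }
}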