Skip to content

Commit

Permalink
Address comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Sep 1, 2020
1 parent 5db1e0f commit 8bec8a3
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ object ResolveUnion extends Rule[LogicalPlan] {
* `col` expression.
*/
private def addFields(col: NamedExpression, target: StructType): Option[Expression] = {
require(col.dataType.isInstanceOf[StructType], "Only support StructType.")
assert(col.dataType.isInstanceOf[StructType], "Only support StructType.")

val resolver = SQLConf.get.resolver
val missingFields =
Expand All @@ -52,30 +52,28 @@ object ResolveUnion extends Rule[LogicalPlan] {
}

private def addFieldsInto(col: Expression, base: String, fields: Seq[StructField]): Expression = {
var currCol = col
fields.foreach { field =>
fields.foldLeft(col) { case (currCol, field) =>
field.dataType match {
case dt: AtomicType =>
// We need to sort columns in result, because we might add another column in other side.
// E.g., we want to union two structs "a int, b long" and "a int, c string".
// If we don't sort, we will have "a int, b long, c string" and "a int, c string, b long",
// which are not compatible.
currCol = WithFields(currCol, s"$base${field.name}", Literal(null, dt),
sortColumns = true)
case st: StructType =>
val resolver = SQLConf.get.resolver
val colField = currCol.dataType.asInstanceOf[StructType]
.find(f => resolver(f.name, field.name))
if (colField.isEmpty) {
// The whole struct is missing. Add a null.
currCol = WithFields(currCol, s"$base${field.name}", Literal(null, st),
sortColumns = true)
WithFields(currCol, s"$base${field.name}", Literal(null, st),
sortOutputColumns = true)
} else {
currCol = addFieldsInto(currCol, s"$base${field.name}.", st.fields)
addFieldsInto(currCol, s"$base${field.name}.", st.fields)
}
case dt =>
// We need to sort columns in result, because we might add another column in other side.
// E.g., we want to union two structs "a int, b long" and "a int, c string".
// If we don't sort, we will have "a int, b long, c string" and "a int, c string, b long",
// which are not compatible.
WithFields(currCol, s"$base${field.name}", Literal(null, dt),
sortOutputColumns = true)
}
}
currCol
}

private def compareAndAddFields(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ case class WithFields(
structExpr: Expression,
names: Seq[String],
valExprs: Seq[Expression],
sortColumns: Boolean = false) extends Unevaluable {
sortOutputColumns: Boolean = false) extends Unevaluable {

assert(names.length == valExprs.length)

Expand Down Expand Up @@ -589,7 +589,7 @@ case class WithFields(
}
}

val finalExprs = if (sortColumns) {
val finalExprs = if (sortOutputColumns) {
newExprs.sortBy(_._1).flatMap { case (name, expr) => Seq(Literal(name), expr) }
} else {
newExprs.flatMap { case (name, expr) => Seq(Literal(name), expr) }
Expand All @@ -616,24 +616,24 @@ object WithFields {
col: Expression,
fieldName: String,
expr: Expression,
sortColumns: Boolean): Expression = {
sortOutputColumns: Boolean): Expression = {
val nameParts = if (fieldName.isEmpty) {
fieldName :: Nil
} else {
CatalystSqlParser.parseMultipartIdentifier(fieldName)
}
withFieldHelper(col, nameParts, Nil, expr, sortColumns)
withFieldHelper(col, nameParts, Nil, expr, sortOutputColumns)
}

private def withFieldHelper(
struct: Expression,
namePartsRemaining: Seq[String],
namePartsDone: Seq[String],
value: Expression,
sortColumns: Boolean) : WithFields = {
sortOutputColumns: Boolean) : WithFields = {
val name = namePartsRemaining.head
if (namePartsRemaining.length == 1) {
WithFields(struct, name :: Nil, value :: Nil, sortColumns)
WithFields(struct, name :: Nil, value :: Nil, sortOutputColumns)
} else {
val newNamesRemaining = namePartsRemaining.tail
val newNamesDone = namePartsDone :+ name
Expand All @@ -650,8 +650,8 @@ object WithFields {
namePartsRemaining = newNamesRemaining,
namePartsDone = newNamesDone,
value = value,
sortColumns = sortColumns)
WithFields(struct, name :: Nil, newValue :: Nil, sortColumns)
sortOutputColumns = sortOutputColumns)
WithFields(struct, name :: Nil, newValue :: Nil, sortOutputColumns)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,5 +129,45 @@ class StructTypeSuite extends SparkFunSuite {
val missing4 = StructType.fromDDL(
"c2 STRUCT<c4: STRUCT<c5: INT>>")
assert(StructType.findMissingFields(source4, schema, resolver).sameType(missing4))

val schemaWithArray = StructType.fromDDL(
"c1 INT, c2 ARRAY<STRUCT<c3: INT, c4: LONG>>")
val source5 = StructType.fromDDL(
"c1 INT")
val missing5 = StructType.fromDDL(
"c2 ARRAY<STRUCT<c3: INT, c4: LONG>>")
assert(StructType.findMissingFields(source5, schemaWithArray, resolver).sameType(missing5))

val schemaWithMap1 = StructType.fromDDL(
"c1 INT, c2 MAP<STRUCT<c3: INT, c4: LONG>, STRING>, c3 LONG")
val source6 = StructType.fromDDL(
"c1 INT, c3 LONG")
val missing6 = StructType.fromDDL(
"c2 MAP<STRUCT<c3: INT, c4: LONG>, STRING>")
assert(StructType.findMissingFields(source6, schemaWithMap1, resolver).sameType(missing6))

val schemaWithMap2 = StructType.fromDDL(
"c1 INT, c2 MAP<STRING, STRUCT<c3: INT, c4: LONG>>, c3 STRING")
val source7 = StructType.fromDDL(
"c1 INT, c3 STRING")
val missing7 = StructType.fromDDL(
"c2 MAP<STRING, STRUCT<c3: INT, c4: LONG>>")
assert(StructType.findMissingFields(source7, schemaWithMap2, resolver).sameType(missing7))

// Unsupported: nested struct in array, map
val source8 = StructType.fromDDL(
"c1 INT, c2 ARRAY<STRUCT<c3: INT>>")
// `findMissingFields` doesn't support looking into nested struct in array type.
assert(StructType.findMissingFields(source8, schemaWithArray, resolver).length == 0)

val source9 = StructType.fromDDL(
"c1 INT, c2 MAP<STRUCT<c3: INT>, STRING>, c3 LONG")
// `findMissingFields` doesn't support looking into nested struct in map type.
assert(StructType.findMissingFields(source9, schemaWithMap1, resolver).length == 0)

val source10 = StructType.fromDDL(
"c1 INT, c2 MAP<STRING, STRUCT<c3: INT>>, c3 STRING")
// `findMissingFields` doesn't support looking into nested struct in map type.
assert(StructType.findMissingFields(source10, schemaWithMap2, resolver).length == 0)
}
}

0 comments on commit 8bec8a3

Please sign in to comment.