[SPARK-21043][SQL] Add unionByName in Dataset #18300

Closed · wants to merge 11 commits · changes shown from 10 commits
60 changes: 60 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -52,6 +52,7 @@ import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.Utils
@@ -1733,6 +1734,65 @@ class Dataset[T] private[sql](
CombineUnions(Union(logicalPlan, other.logicalPlan))
}

/**
* Returns a new Dataset containing the union of rows in this Dataset and another Dataset.
*
* This is different from both `UNION ALL` and `UNION DISTINCT` in SQL. To do a SQL-style set
* union (that does deduplication of elements), use this function followed by a [[distinct]].
*
* The difference between this function and [[union]] is that this function
* resolves columns by name (not by position):
*
* {{{
* val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2")
* val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0")
* df1.unionByName(df2).show
*
* // output:
* // +----+----+----+
* // |col0|col1|col2|
* // +----+----+----+
* // |   1|   2|   3|
* // |   6|   4|   5|
* // +----+----+----+
* }}}
*
* @group typedrel
* @since 2.3.0
*/
def unionByName(other: Dataset[T]): Dataset[T] = withSetOperator {
// Check column name duplication
val resolver = sparkSession.sessionState.analyzer.resolver
val leftOutputAttrs = logicalPlan.output
val rightOutputAttrs = other.logicalPlan.output

SchemaUtils.checkColumnNameDuplication(
leftOutputAttrs.map(_.name),
"in the left attributes",
sparkSession.sessionState.conf.caseSensitiveAnalysis)
SchemaUtils.checkColumnNameDuplication(
rightOutputAttrs.map(_.name),
"in the right attributes",
sparkSession.sessionState.conf.caseSensitiveAnalysis)

// Builds a project list for `other` based on `logicalPlan` output names
val rightProjectList = leftOutputAttrs.map { lattr =>
rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse {
throw new AnalysisException(
s"""Cannot resolve column name "${lattr.name}" among """ +
s"""(${rightOutputAttrs.map(_.name).mkString(", ")})""")
}
}
maropu (Member Author) commented:
@gatorsmile How about this impl.?


// Delegates failure checks to `CheckAnalysis`
val notFoundAttrs = rightOutputAttrs.diff(rightProjectList)
val rightChild = Project(rightProjectList ++ notFoundAttrs, other.logicalPlan)

// This breaks caching, but it's usually ok because it addresses a very specific use case:
// using union to union many files or partitions.
CombineUnions(Union(logicalPlan, rightChild))
}
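As the doc comment above notes, a SQL-style UNION (with deduplication) is unionByName followed by distinct(). A minimal sketch, with hypothetical column names and data:

// Hypothetical example; "a" and "b" are illustrative column names.
val d1 = Seq((1, 2)).toDF("a", "b")
val d2 = Seq((2, 1), (1, 2)).toDF("b", "a")
// d2's rows resolve by name to (a = 1, b = 2) and (a = 2, b = 1),
// so the raw union holds (1, 2), (1, 2), (2, 1).
d1.unionByName(d2).distinct().collect()
// rows after deduplication: (1, 2) and (2, 1)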

/**
* Returns a new Dataset containing rows only in both this Dataset and another Dataset.
* This is equivalent to `INTERSECT` in SQL.
87 changes: 87 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -111,6 +111,93 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
)
}

test("union by name") {
var df1 = Seq((1, 2, 3)).toDF("a", "b", "c")
var df2 = Seq((3, 1, 2)).toDF("c", "a", "b")
val df3 = Seq((2, 3, 1)).toDF("b", "c", "a")
val unionDf = df1.unionByName(df2.unionByName(df3))
checkAnswer(unionDf,
Row(1, 2, 3) :: Row(1, 2, 3) :: Row(1, 2, 3) :: Nil
)
Member commented:
Hi, @maropu. To be clearer, could you add more test cases requiring type coercions here?

maropu (Member Author) replied:
yea, sure.

// Check if adjacent unions are combined into a single one
assert(unionDf.queryExecution.optimizedPlan.collect { case u: Union => true }.size == 1)

// Check failure cases
df1 = Seq((1, 2)).toDF("a", "c")
df2 = Seq((3, 4, 5)).toDF("a", "b", "c")
var errMsg = intercept[AnalysisException] {
df1.unionByName(df2)
}.getMessage
assert(errMsg.contains(
"Union can only be performed on tables with the same number of columns, " +
"but the first table has 2 columns and the second table has 3 columns"))

df1 = Seq((1, 2, 3)).toDF("a", "b", "c")
df2 = Seq((4, 5, 6)).toDF("a", "c", "d")
errMsg = intercept[AnalysisException] {
df1.unionByName(df2)
}.getMessage
assert(errMsg.contains("""Cannot resolve column name "b" among (a, c, d)"""))
}
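To contrast the by-name resolution tested above with the positional union the doc comment mentions, a small sketch (hypothetical data):

val p = Seq((1, 2)).toDF("a", "b")
val q = Seq((2, 1)).toDF("b", "a")
p.union(q).collect()        // by position: q's (2, 1) lands as (a = 2, b = 1)
p.unionByName(q).collect()  // by name: q's (b = 2, a = 1) lands as (a = 1, b = 2)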

test("union by name - type coercion") {
var df1 = Seq((1, "a")).toDF("c0", "c1")
var df2 = Seq((3, 1L)).toDF("c1", "c0")
checkAnswer(df1.unionByName(df2), Row(1L, "a") :: Row(1L, "3") :: Nil)

df1 = Seq((1, 1.0)).toDF("c0", "c1")
df2 = Seq((8L, 3.0)).toDF("c1", "c0")
checkAnswer(df1.unionByName(df2), Row(1.0, 1.0) :: Row(3.0, 8.0) :: Nil)

df1 = Seq((2.0f, 7.4)).toDF("c0", "c1")
df2 = Seq(("a", 4.0)).toDF("c1", "c0")
checkAnswer(df1.unionByName(df2), Row(2.0, "7.4") :: Row(4.0, "a") :: Nil)

df1 = Seq((1, "a", 3.0)).toDF("c0", "c1", "c2")
df2 = Seq((1.2, 2, "bc")).toDF("c2", "c0", "c1")
val df3 = Seq(("def", 1.2, 3)).toDF("c1", "c2", "c0")
checkAnswer(df1.unionByName(df2.unionByName(df3)),
Row(1, "a", 3.0) :: Row(2, "bc", 1.2) :: Row(3, "def", 1.2) :: Nil
)
}
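These coercions widen the result schema. Mirroring the first case in this test, a quick way to inspect the widened types (nullability shown is indicative):

val x = Seq((1, "a")).toDF("c0", "c1")  // c0: int, c1: string
val y = Seq((3, 1L)).toDF("c1", "c0")   // c1: int, c0: long
x.unionByName(y).printSchema()
// root
//  |-- c0: long (nullable = false)
//  |-- c1: string (nullable = true)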

test("union by name - check case sensitivity") {
def checkCaseSensitiveTest(): Unit = {
val df1 = Seq((1, 2, 3)).toDF("ab", "cd", "ef")
val df2 = Seq((4, 5, 6)).toDF("cd", "ef", "AB")
checkAnswer(df1.unionByName(df2), Row(1, 2, 3) :: Row(6, 4, 5) :: Nil)
}
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
val errMsg2 = intercept[AnalysisException] {
checkCaseSensitiveTest()
}.getMessage
assert(errMsg2.contains("""Cannot resolve column name "ab" among (cd, ef, AB)"""))
}
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
checkCaseSensitiveTest()
}
}
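Outside the test harness, the same switch is the spark.sql.caseSensitive runtime conf; a sketch of the case-insensitive path (hypothetical data):

spark.conf.set("spark.sql.caseSensitive", "false")
val u = Seq((1, 2)).toDF("ab", "cd")
val v = Seq((4, 3)).toDF("CD", "AB")
u.unionByName(v).collect()  // "AB" resolves to "ab" and "CD" to "cd": rows (1, 2) and (3, 4)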

test("union by name - check name duplication") {
Seq((true, ("a", "a")), (false, ("aA", "Aa"))).foreach { case (caseSensitive, (c0, c1)) =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
var df1 = Seq((1, 1)).toDF(c0, c1)
var df2 = Seq((1, 1)).toDF("c0", "c1")
var errMsg = intercept[AnalysisException] {
df1.unionByName(df2)
}.getMessage
assert(errMsg.contains("Found duplicate column(s) in the left attributes:"))
df1 = Seq((1, 1)).toDF("c0", "c1")
Member commented:
Nit: indents.

maropu (Member Author) replied:
Updated

df2 = Seq((1, 1)).toDF(c0, c1)
errMsg = intercept[AnalysisException] {
df1.unionByName(df2)
}.getMessage
assert(errMsg.contains("Found duplicate column(s) in the right attributes:"))
}
}
}

test("empty data frame") {
assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String])
assert(spark.emptyDataFrame.count() === 0)