[SPARK-29427][SQL] Add API to convert RelationalGroupedDataset to KeyValueGroupedDataset

### What changes were proposed in this pull request?

This PR proposes to add an `as` API to `RelationalGroupedDataset`. It creates a `KeyValueGroupedDataset` instance from the given grouping expressions, instead of from a typed key-extraction function as in the `groupByKey` API. Because it leverages existing columns, it can reuse existing data partitioning, if any, in operations like `cogroup`.
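
For illustration, a minimal sketch of the call shape (assuming a DataFrame `df` with an integer column `a`, a `SparkSession` named `spark`, and the final two-type-parameter signature shown in the diff below):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import spark.implicits._  // supplies the Encoder[Int] for the key

// Reuse the existing column "a" as the grouping key; unlike groupByKey, no
// key-extraction function (and hence no extra AppendColumns node) is involved,
// so any existing partitioning on "a" can be reused downstream.
implicit val valueEncoder = RowEncoder(df.schema)  // Encoder[Row] for the untyped values
val kv = df.groupBy("a").as[Int, Row]              // KeyValueGroupedDataset[Int, Row]
```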

### Why are the changes needed?

Currently, if users want to cogroup DataFrames, there is no good way to do it other than going through `KeyValueGroupedDataset`, which has two drawbacks:

1. `KeyValueGroupedDataset` ignores any existing data partitioning, which is a problem.
2. `groupByKey` calls a typed function to produce an additional key column, so existing columns cannot be reused even when you only need to group by them.

Both issues show up in the example plans below.

```scala
// df1 and df2 are already partitioned and sorted by column "a".
val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c")
  .repartition($"a").sortWithinPartitions("a")
val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a", "b", "c")
  .repartition($"a").sortWithinPartitions("a")
```
```scala
// This groupBy.as.cogroup won't unnecessarily repartition the data.
// (An implicit Encoder[Row] is needed for the value side, e.g. via RowEncoder.)
implicit val valueEncoder = RowEncoder(df1.schema)
val df3 = df1.groupBy("a").as[Int, Row]
  .cogroup(df2.groupBy("a").as[Int, Row]) { case (key, data1, data2) =>
    data1.zip(data2).map { p =>
      p._1.getInt(2) + p._2.getInt(2)
    }
  }
```

```
== Physical Plan ==
*(5) SerializeFromObject [input[0, int, false] AS value#11247]
+- CoGroup org.apache.spark.sql.DataFrameSuite$$Lambda$4922/12067092816eec1b6f, a#11209: int, createexternalrow(a#11209, b#11210, c#11211, StructField(a,IntegerType,false), StructField(b,IntegerType,false), StructField(c,IntegerType,false)), createexternalrow(a#11225, b#11226, c#11227, StructField(a,IntegerType,false), StructField(b,IntegerType,false), StructField(c,IntegerType,false)), [a#11209], [a#11225], [a#11209, b#11210, c#11211], [a#11225, b#11226, c#11227], obj#11246: int
   :- *(2) Sort [a#11209 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(a#11209, 5), false, [id=#10218]
   :     +- *(1) Project [_1#11202 AS a#11209, _2#11203 AS b#11210, _3#11204 AS c#11211]
   :        +- *(1) LocalTableScan [_1#11202, _2#11203, _3#11204]
   +- *(4) Sort [a#11225 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(a#11225, 5), false, [id=#10223]
         +- *(3) Project [_1#11218 AS a#11225, _2#11219 AS b#11226, _3#11220 AS c#11227]
            +- *(3) LocalTableScan [_1#11218, _2#11219, _3#11220]
```

```scala
// The current groupByKey approach adds an AppendColumns node and repartitions
// the data again (note the extra Exchange/Sort pair in the plan below).
val df4 = df1.groupByKey(r => r.getInt(0)).cogroup(df2.groupByKey(r => r.getInt(0))) {
  case (key, data1, data2) =>
    data1.zip(data2).map { p =>
      p._1.getInt(2) + p._2.getInt(2)
    }
}
```

```
== Physical Plan ==
*(7) SerializeFromObject [input[0, int, false] AS value#11257]
+- CoGroup org.apache.spark.sql.DataFrameSuite$$Lambda$4933/138102700737171997, value#11252: int, createexternalrow(a#11209, b#11210, c#11211, StructField(a,IntegerType,false), StructField(b,IntegerType,false), StructField(c,IntegerType,false)), createexternalrow(a#11225, b#11226, c#11227, StructField(a,IntegerType,false), StructField(b,IntegerType,false), StructField(c,IntegerType,false)), [value#11252], [value#11254], [a#11209, b#11210, c#11211], [a#11225, b#11226, c#11227], obj#11256: int
   :- *(3) Sort [value#11252 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(value#11252, 5), true, [id=#10302]
   :     +- AppendColumns org.apache.spark.sql.DataFrameSuite$$Lambda$4930/19529195347ce07f47, createexternalrow(a#11209, b#11210, c#11211, StructField(a,IntegerType,false), StructField(b,IntegerType,false), StructField(c,IntegerType,false)), [input[0, int, false] AS value#11252]
   :        +- *(2) Sort [a#11209 ASC NULLS FIRST], false, 0
   :           +- Exchange hashpartitioning(a#11209, 5), false, [id=#10297]
   :              +- *(1) Project [_1#11202 AS a#11209, _2#11203 AS b#11210, _3#11204 AS c#11211]
   :                 +- *(1) LocalTableScan [_1#11202, _2#11203, _3#11204]
   +- *(6) Sort [value#11254 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(value#11254, 5), true, [id=#10312]
         +- AppendColumns org.apache.spark.sql.DataFrameSuite$$Lambda$4932/15265288491f0e0c1f, createexternalrow(a#11225, b#11226, c#11227, StructField(a,IntegerType,false), StructField(b,IntegerType,false), StructField(c,IntegerType,false)), [input[0, int, false] AS value#11254]
            +- *(5) Sort [a#11225 ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(a#11225, 5), false, [id=#10307]
                  +- *(4) Project [_1#11218 AS a#11225, _2#11219 AS b#11226, _3#11220 AS c#11227]
                     +- *(4) LocalTableScan [_1#11218, _2#11219, _3#11220]
```

### Does this PR introduce any user-facing change?

Yes, this adds a new `as` API to `RelationalGroupedDataset`. Users can use it to create a `KeyValueGroupedDataset` and perform `cogroup` on it.
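
As a rough, illustrative sketch only (reusing `df1` from the examples above, with `spark.implicits._` in scope), the returned `KeyValueGroupedDataset` also supports the other typed operations, not just `cogroup`, e.g. `mapGroups`:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder

// Encoder[Row] for the untyped values; Encoder[Int] for the key comes from spark.implicits._.
implicit val valueEncoder = RowEncoder(df1.schema)

// Count the rows in each group without supplying a key-extraction function.
val counts = df1.groupBy("a").as[Int, Row]
  .mapGroups { (key, rows) => (key, rows.size) }
  .toDF("a", "count")
```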

### How was this patch tested?

Unit tests.

Closes #26509 from viirya/SPARK-29427-2.

Lead-authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Co-authored-by: Liang-Chi Hsieh <liangchi@uber.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
viirya authored and dongjoon-hyun committed Nov 22, 2019
1 parent 6e581cf commit 6b0e391aa49acd5029d00fefc0c90fcdfdf88cb6
@@ -26,6 +26,7 @@ import org.apache.spark.annotation.Stable
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -129,6 +130,37 @@ class RelationalGroupedDataset protected[sql](
    (inputExpr: Expression) => exprToFunc(inputExpr)
  }

  /**
   * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions
   * of current `RelationalGroupedDataset`.
   *
   * @since 3.0.0
   */
  def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = {
    val keyEncoder = encoderFor[K]
    val valueEncoder = encoderFor[T]

    // Resolves grouping expressions.
    val dummyPlan = Project(groupingExprs.map(alias), LocalRelation(df.logicalPlan.output))
    val analyzedPlan = df.sparkSession.sessionState.analyzer.execute(dummyPlan)
      .asInstanceOf[Project]
    df.sparkSession.sessionState.analyzer.checkAnalysis(analyzedPlan)
    val aliasedGroupings = analyzedPlan.projectList

    // Adds the grouping expressions that are not in base DataFrame into outputs.
    val addedCols = aliasedGroupings.filter(g => !df.logicalPlan.outputSet.contains(g.toAttribute))
    val qe = Dataset.ofRows(
      df.sparkSession,
      Project(df.logicalPlan.output ++ addedCols, df.logicalPlan)).queryExecution

    new KeyValueGroupedDataset(
      keyEncoder,
      valueEncoder,
      qe,
      df.logicalPlan.output,
      aliasedGroupings.map(_.toAttribute))
  }

  /**
   * (Scala-specific) Compute aggregates by specifying the column names and
   * aggregate methods. The resulting `DataFrame` will also contain the grouping columns.
@@ -30,6 +30,7 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
@@ -2221,4 +2222,62 @@ class DataFrameSuite extends QueryTest with SharedSparkSession {
    val idTuples = sampled.collect().map(row => row.getLong(0) -> row.getLong(1))
    assert(idTuples.length == idTuples.toSet.size)
  }

  test("groupBy.as") {
    val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c")
      .repartition($"a", $"b").sortWithinPartitions("a", "b")
    val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a", "b", "c")
      .repartition($"a", $"b").sortWithinPartitions("a", "b")

    implicit val valueEncoder = RowEncoder(df1.schema)

    val df3 = df1.groupBy("a", "b").as[GroupByKey, Row]
      .cogroup(df2.groupBy("a", "b").as[GroupByKey, Row]) { case (_, data1, data2) =>
        data1.zip(data2).map { p =>
          p._1.getInt(2) + p._2.getInt(2)
        }
      }.toDF

    checkAnswer(df3.sort("value"), Row(7) :: Row(9) :: Nil)

    // Assert that no extra shuffle introduced by cogroup.
    val exchanges = df3.queryExecution.executedPlan.collect {
      case h: ShuffleExchangeExec => h
    }
    assert(exchanges.size == 2)
  }

  test("groupBy.as: custom grouping expressions") {
    val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a1", "b", "c")
      .repartition($"a1", $"b").sortWithinPartitions("a1", "b")
    val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a1", "b", "c")
      .repartition($"a1", $"b").sortWithinPartitions("a1", "b")

    implicit val valueEncoder = RowEncoder(df1.schema)

    val groupedDataset1 = df1.groupBy(($"a1" + 1).as("a"), $"b").as[GroupByKey, Row]
    val groupedDataset2 = df2.groupBy(($"a1" + 1).as("a"), $"b").as[GroupByKey, Row]

    val df3 = groupedDataset1
      .cogroup(groupedDataset2) { case (_, data1, data2) =>
        data1.zip(data2).map { p =>
          p._1.getInt(2) + p._2.getInt(2)
        }
      }.toDF

    checkAnswer(df3.sort("value"), Row(7) :: Row(9) :: Nil)
  }

  test("groupBy.as: throw AnalysisException for unresolved grouping expr") {
    val df = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c")

    implicit val valueEncoder = RowEncoder(df.schema)

    val err = intercept[AnalysisException] {
      df.groupBy($"d", $"b").as[GroupByKey, Row]
    }
    assert(err.getMessage.contains("cannot resolve '`d`'"))
  }
}

case class GroupByKey(a: Int, b: Int)
@@ -1861,6 +1861,27 @@ class DatasetSuite extends QueryTest with SharedSparkSession {
      }
    }
  }

test("groupBy.as") {
val df1 = Seq(DoubleData(1, "one"), DoubleData(2, "two"), DoubleData( 3, "three")).toDS()
.repartition($"id").sortWithinPartitions("id")
val df2 = Seq(DoubleData(5, "one"), DoubleData(1, "two"), DoubleData( 3, "three")).toDS()
.repartition($"id").sortWithinPartitions("id")

val df3 = df1.groupBy("id").as[Int, DoubleData]
.cogroup(df2.groupBy("id").as[Int, DoubleData]) { case (key, data1, data2) =>
if (key == 1) {
Iterator(DoubleData(key, (data1 ++ data2).foldLeft("")((cur, next) => cur + next.val1)))
} else Iterator.empty
}
checkDataset(df3, DoubleData(1, "onetwo"))

// Assert that no extra shuffle introduced by cogroup.
val exchanges = df3.queryExecution.executedPlan.collect {
case h: ShuffleExchangeExec => h
}
assert(exchanges.size == 2)
}
}

object AssertExecutionId {
