Skip to content

Commit

Permalink
[SPARK-22223][SQL] ObjectHashAggregate should not introduce unnecessary shuffle
Browse files Browse the repository at this point in the history

## What changes were proposed in this pull request?

`ObjectHashAggregateExec` should override `outputPartitioning` to report its child's output partitioning, so that the planner does not introduce an unnecessary shuffle.

## How was this patch tested?

Added Jenkins test.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #19501 from viirya/SPARK-22223.
  • Loading branch information
viirya authored and cloud-fan committed Oct 16, 2017
1 parent 13c1559 commit 0ae9649
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ case class ObjectHashAggregateExec(
}
}

// Report the child's output partitioning as this operator's own partitioning.
// Per SPARK-22223, overriding this is what avoids an unnecessary shuffle being
// introduced for ObjectHashAggregateExec.
override def outputPartitioning: Partitioning = child.outputPartitioning

protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numOutputRows = longMetric("numOutputRows")
val aggTime = longMetric("aggTime")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.util.Random

import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -636,4 +637,33 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"),
Seq(Row(3, 4, 9)))
}

test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") {
  withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
    // Source rows, explicitly repartitioned by column "a" before any aggregation.
    val partitionedInput =
      Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4))
        .toDF("a", "b", "c")
        .repartition(col("a"))

    // First aggregation: group by ("a", "b") and collect struct column "d" into "e".
    val groupedByAAndB = partitionedInput
      .withColumn("d", expr("(a, b, c)"))
      .groupBy("a", "b")
      .agg(collect_list("d").as("e"))

    // Second aggregation: group by "a" alone and collect struct column "f" into "g".
    val groupedByA = groupedByAAndB
      .withColumn("f", expr("(b, e)"))
      .groupBy("a")
      .agg(collect_list("f").as("g"))

    val physicalPlan = groupedByA.queryExecution.executedPlan

    // No sort-based aggregate operator may appear in the executed plan.
    val sortAggregates = physicalPlan.collect { case node: SortAggregateExec => node }
    assert(sortAggregates.isEmpty)

    // At least one object-hash aggregate operator must appear in the executed plan.
    val objectHashAggregates = physicalPlan.collect { case node: ObjectHashAggregateExec => node }
    assert(objectHashAggregates.nonEmpty)

    // Exactly one shuffle exchange is expected in the whole plan.
    val shuffleExchanges = physicalPlan.collect { case node: ShuffleExchangeExec => node }
    assert(shuffleExchanges.length == 1)
  }
}
}

0 comments on commit 0ae9649

Please sign in to comment.