Skip to content

Commit

Permalink
[SPARK-20381][SQL] Add SQL metrics of numOutputRows for ObjectHashAggregateExec
Browse files Browse the repository at this point in the history

## What changes were proposed in this pull request?

ObjectHashAggregateExec is missing the numOutputRows SQL metric; this change adds it.

## How was this patch tested?

Added unit tests covering the new metric.

Author: Yucai <yucai.yu@intel.com>

Closes #17678 from yucai/objectAgg_numOutputRows.
  • Loading branch information
Yucai authored and gatorsmile committed May 5, 2017
1 parent b9ad2d1 commit 41439fd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering}
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator
Expand All @@ -39,7 +40,8 @@ class ObjectAggregationIterator(
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
originalInputAttributes: Seq[Attribute],
inputRows: Iterator[InternalRow],
fallbackCountThreshold: Int)
fallbackCountThreshold: Int,
numOutputRows: SQLMetric)
extends AggregationIterator(
groupingExpressions,
originalInputAttributes,
Expand Down Expand Up @@ -83,7 +85,9 @@ class ObjectAggregationIterator(

/**
 * Produces the next aggregated output row and records it in `numOutputRows`.
 *
 * @return the row generated from the next entry of `aggBufferIterator`
 */
override final def next(): UnsafeRow = {
  val entry = aggBufferIterator.next()
  // Generate the output exactly once per entry: an earlier, discarded call to
  // generateOutput with the same arguments was redundant work.
  val res = generateOutput(entry.groupingKey, entry.aggregationBuffer)
  // Count the row only after it has been generated successfully.
  numOutputRows += 1
  res
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ case class ObjectHashAggregateExec(
newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
child.output,
iter,
fallbackCountThreshold)
fallbackCountThreshold,
numOutputRows)
if (!hasInput && groupingExpressions.isEmpty) {
numOutputRows += 1
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
)
}

test("ObjectHashAggregate metrics") {
  // Both queries below are assumed to be planned as
  //   ... -> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1)
  //       -> ObjectHashAggregate(nodeId = 0)
  val aggNodeName = "ObjectHashAggregate"
  val outputRowsMetric = "number of output rows"

  // Aggregation without grouping keys over 2 partitions.
  val globalAggDf = testData2.groupBy().agg(collect_set('a))
  testSparkPlanMetrics(globalAggDf, 1, Map(
    2L -> (aggNodeName, Map(outputRowsMetric -> 2L)),
    0L -> (aggNodeName, Map(outputRowsMetric -> 1L)))
  )

  // Grouped aggregation: 2 partitions, each of which contains 2 keys.
  val groupedAggDf = testData2.groupBy('a).agg(collect_set('a))
  testSparkPlanMetrics(groupedAggDf, 1, Map(
    2L -> (aggNodeName, Map(outputRowsMetric -> 4L)),
    0L -> (aggNodeName, Map(outputRowsMetric -> 3L)))
  )
}

test("Sort metrics") {
// Assume the execution plan is
// WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1))
Expand Down

0 comments on commit 41439fd

Please sign in to comment.