From 91c88ebaf888cf4ca1885af5fa18d2f1c8a4f926 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 10 Oct 2015 20:19:37 -0700 Subject: [PATCH 1/3] Simplify wrapping of iterator into KVIterator. --- .../aggregate/AggregationIterator.scala | 83 ------------------- .../execution/aggregate/KVIteratorUtils.scala | 46 ++++++++++ .../aggregate/SortBasedAggregate.scala | 23 +++-- .../SortBasedAggregationIterator.scala | 44 ---------- 4 files changed, 61 insertions(+), 135 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/KVIteratorUtils.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 5f7341e88c7c9..8e0fbd109b413 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -21,7 +21,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.unsafe.KVIterator import scala.collection.mutable.ArrayBuffer @@ -412,85 +411,3 @@ abstract class AggregationIterator( */ protected def newBuffer: MutableRow } - -object AggregationIterator { - def kvIterator( - groupingExpressions: Seq[NamedExpression], - newProjection: (Seq[Expression], Seq[Attribute]) => Projection, - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]): KVIterator[InternalRow, InternalRow] = { - new KVIterator[InternalRow, InternalRow] { - private[this] val groupingKeyGenerator = newProjection(groupingExpressions, inputAttributes) - - private[this] var groupingKey: InternalRow = _ - - private[this] var value: InternalRow = _ - - override def next(): Boolean = { - if (inputIter.hasNext) { - // Read the next input row. - val inputRow = inputIter.next() - // Get groupingKey based on groupingExpressions. - groupingKey = groupingKeyGenerator(inputRow) - // The value is the inputRow. - value = inputRow - true - } else { - false - } - } - - override def getKey(): InternalRow = { - groupingKey - } - - override def getValue(): InternalRow = { - value - } - - override def close(): Unit = { - // Do nothing - } - } - } - - def unsafeKVIterator( - groupingExpressions: Seq[NamedExpression], - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]): KVIterator[UnsafeRow, InternalRow] = { - new KVIterator[UnsafeRow, InternalRow] { - private[this] val groupingKeyGenerator = - UnsafeProjection.create(groupingExpressions, inputAttributes) - - private[this] var groupingKey: UnsafeRow = _ - - private[this] var value: InternalRow = _ - - override def next(): Boolean = { - if (inputIter.hasNext) { - // Read the next input row. - val inputRow = inputIter.next() - // Get groupingKey based on groupingExpressions. - groupingKey = groupingKeyGenerator.apply(inputRow) - // The value is the inputRow. - value = inputRow - true - } else { - false - } - } - - override def getKey(): UnsafeRow = { - groupingKey - } - - override def getValue(): InternalRow = { - value - } - - override def close(): Unit = { - // Do nothing - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/KVIteratorUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/KVIteratorUtils.scala new file mode 100644 index 0000000000000..02fc5c5871b09 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/KVIteratorUtils.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.KVIterator + +object KVIteratorUtils { + def fromIterator( + inputIter: Iterator[InternalRow], + keyProjection: InternalRow => InternalRow, + valueProjection: InternalRow => InternalRow): KVIterator[InternalRow, InternalRow] = { + new KVIterator[InternalRow, InternalRow] { + private[this] var key: InternalRow = _ + private[this] var value: InternalRow = _ + override def getKey: InternalRow = key + override def getValue: InternalRow = value + override def close(): Unit = { /* Do nothing */ } + override def next(): Boolean = { + if (inputIter.hasNext) { + val inputRow = inputIter.next() + key = keyProjection(inputRow) + value = valueProjection(inputRow) + true + } else { + false + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index f4c14a9b3556f..6a26c27c76b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} -import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType case class SortBasedAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -79,18 +78,26 @@ case class SortBasedAggregate( // so return an empty iterator. Iterator[InternalRow]() } else { - val outputIter = SortBasedAggregationIterator.createFromInputIterator( - groupingExpressions, + val groupingKeyProjection = if (UnsafeProjection.canSupport(groupingExpressions)) { + UnsafeProjection.create(groupingExpressions, child.output) + } else { + newProjection(groupingExpressions, child.output) + } + val kvIterator = KVIteratorUtils.fromIterator( + iter, + keyProjection = groupingKeyProjection, + valueProjection = identity) + val outputIter = new SortBasedAggregationIterator( + groupingExpressions.map(_.toAttribute), + child.output, + kvIterator, nonCompleteAggregateExpressions, nonCompleteAggregateAttributes, completeAggregateExpressions, completeAggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection _, - newProjection _, - child.output, - iter, + newMutableProjection, outputsUnsafeRows, numInputRows, numOutputRows) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index a9e5d175bf895..f78c0058ac34a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -170,47 +170,3 @@ class SortBasedAggregationIterator( generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer) } } - -object SortBasedAggregationIterator { - // scalastyle:off - def createFromInputIterator( - groupingExprs: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], - nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], - completeAggregateAttributes: Seq[Attribute], - initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - newProjection: (Seq[Expression], Seq[Attribute]) => Projection, - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow], - outputsUnsafeRows: Boolean, - numInputRows: LongSQLMetric, - numOutputRows: LongSQLMetric): SortBasedAggregationIterator = { - val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) { - AggregationIterator.unsafeKVIterator( - groupingExprs, - inputAttributes, - inputIter).asInstanceOf[KVIterator[InternalRow, InternalRow]] - } else { - AggregationIterator.kvIterator(groupingExprs, newProjection, inputAttributes, inputIter) - } - - new SortBasedAggregationIterator( - groupingExprs.map(_.toAttribute), - inputAttributes, - kvIterator, - nonCompleteAggregateExpressions, - nonCompleteAggregateAttributes, - completeAggregateExpressions, - completeAggregateAttributes, - initialInputBufferOffset, - resultExpressions, - newMutableProjection, - outputsUnsafeRows, - numInputRows, - numOutputRows) - } - // scalastyle:on -} From 2b8bf947b0434146131b3af4a314cc89ace8a874 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 10 Oct 2015 21:17:29 -0700 Subject: [PATCH 2/3] Remove KVIterator from SortBasedAggregationIterator --- .../execution/aggregate/KVIteratorUtils.scala | 46 ------------------- .../aggregate/SortBasedAggregate.scala | 9 ++-- .../SortBasedAggregationIterator.scala | 46 +++++++++---------- 3 files changed, 25 insertions(+), 76 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/KVIteratorUtils.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/KVIteratorUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/KVIteratorUtils.scala deleted file mode 100644 index 02fc5c5871b09..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/KVIteratorUtils.scala +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.aggregate - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.unsafe.KVIterator - -object KVIteratorUtils { - def fromIterator( - inputIter: Iterator[InternalRow], - keyProjection: InternalRow => InternalRow, - valueProjection: InternalRow => InternalRow): KVIterator[InternalRow, InternalRow] = { - new KVIterator[InternalRow, InternalRow] { - private[this] var key: InternalRow = _ - private[this] var value: InternalRow = _ - override def getKey: InternalRow = key - override def getValue: InternalRow = value - override def close(): Unit = { /* Do nothing */ } - override def next(): Boolean = { - if (inputIter.hasNext) { - val inputRow = inputIter.next() - key = keyProjection(inputRow) - value = valueProjection(inputRow) - true - } else { - false - } - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index 6a26c27c76b57..4d37106e007f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -81,16 +81,13 @@ case class SortBasedAggregate( val groupingKeyProjection = if (UnsafeProjection.canSupport(groupingExpressions)) { UnsafeProjection.create(groupingExpressions, child.output) } else { - newProjection(groupingExpressions, child.output) + newMutableProjection(groupingExpressions, child.output)() } - val kvIterator = KVIteratorUtils.fromIterator( - iter, - keyProjection = groupingKeyProjection, - valueProjection = identity) val outputIter = new SortBasedAggregationIterator( + groupingKeyProjection, groupingExpressions.map(_.toAttribute), child.output, - kvIterator, + iter, nonCompleteAggregateExpressions, nonCompleteAggregateAttributes, completeAggregateExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index f78c0058ac34a..2d480f98f5d3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -21,16 +21,16 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.unsafe.KVIterator /** * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been * sorted by values of [[groupingKeyAttributes]]. */ class SortBasedAggregationIterator( + groupingKeyProjection: InternalRow => InternalRow, groupingKeyAttributes: Seq[Attribute], valueAttributes: Seq[Attribute], - inputKVIterator: KVIterator[InternalRow, InternalRow], + inputIterator: Iterator[InternalRow], nonCompleteAggregateExpressions: Seq[AggregateExpression2], nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], @@ -90,6 +90,22 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + protected def initialize(): Unit = { + if (inputIterator.hasNext) { + initializeBuffer(sortBasedAggregationBuffer) + val inputRow = inputIterator.next() + nextGroupingKey = groupingKeyProjection(inputRow).copy() + firstRowInNextGroup = inputRow.copy() + numInputRows += 1 + sortedInputHasNewGroup = true + } else { + // This inputIter is empty. + sortedInputHasNewGroup = false + } + } + + initialize() + /** Processes rows in the current group. It will stop when it find a new group. */ protected def processCurrentSortedGroup(): Unit = { currentGroupingKey = nextGroupingKey @@ -101,18 +117,16 @@ class SortBasedAggregationIterator( // The search will stop when we see the next group or there is no // input row left in the iter. - var hasNext = inputKVIterator.next() - while (!findNextPartition && hasNext) { + while (!findNextPartition && inputIterator.hasNext) { // Get the grouping key. - val groupingKey = inputKVIterator.getKey - val currentRow = inputKVIterator.getValue + val inputRow = inputIterator.next() + val groupingKey = groupingKeyProjection(inputRow).copy() + val currentRow = inputRow.copy() numInputRows += 1 // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { processRow(sortBasedAggregationBuffer, currentRow) - - hasNext = inputKVIterator.next() } else { // We find a new group. findNextPartition = true @@ -149,22 +163,6 @@ class SortBasedAggregationIterator( } } - protected def initialize(): Unit = { - if (inputKVIterator.next()) { - initializeBuffer(sortBasedAggregationBuffer) - - nextGroupingKey = inputKVIterator.getKey().copy() - firstRowInNextGroup = inputKVIterator.getValue().copy() - numInputRows += 1 - sortedInputHasNewGroup = true - } else { - // This inputIter is empty. - sortedInputHasNewGroup = false - } - } - - initialize() - def outputForEmptyGroupingKeyWithoutInput(): InternalRow = { initializeBuffer(sortBasedAggregationBuffer) generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer) From ef538621f7415a8502615181f6fb08721a809ff6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 11 Oct 2015 13:20:19 -0700 Subject: [PATCH 3/3] Remove unnecessary copy() --- .../execution/aggregate/SortBasedAggregationIterator.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 2d480f98f5d3d..64c673064f576 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -119,9 +119,8 @@ class SortBasedAggregationIterator( // input row left in the iter. while (!findNextPartition && inputIterator.hasNext) { // Get the grouping key. - val inputRow = inputIterator.next() - val groupingKey = groupingKeyProjection(inputRow).copy() - val currentRow = inputRow.copy() + val currentRow = inputIterator.next() + val groupingKey = groupingKeyProjection(currentRow) numInputRows += 1 // Check if the current row belongs the current input row.