From d468a889495b5f74285aa7cec63ef82b4888cc8f Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sat, 13 Jun 2015 18:00:35 -0700
Subject: [PATCH] Update for InternalRow refactoring

---
 .../execution/UnsafeExternalRowSorter.java    | 18 +++---
 .../spark/sql/AbstractScalaRowIterator.scala  |  4 +-
 .../apache/spark/sql/execution/Exchange.scala | 58 +++++--------------
 .../spark/sql/execution/basicOperators.scala  |  6 +-
 .../sql/execution/joins/SortMergeJoin.scala   | 19 +++---
 5 files changed, 39 insertions(+), 66 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 2380c4614b5e3..7c53ea7bdac24 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -27,7 +27,7 @@
 import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
 import org.apache.spark.sql.AbstractScalaRowIterator;
-import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
 import org.apache.spark.sql.types.StructType;
@@ -43,14 +43,14 @@ final class UnsafeExternalRowSorter {
   private final UnsafeRowConverter rowConverter;
   private final RowComparator rowComparator;
   private final PrefixComparator prefixComparator;
-  private final Function1<Row, Long> prefixComputer;
+  private final Function1<InternalRow, Long> prefixComputer;
 
   public UnsafeExternalRowSorter(
       StructType schema,
-      Ordering<Row> ordering,
+      Ordering<InternalRow> ordering,
       PrefixComparator prefixComparator,
       // TODO: if possible, avoid this boxing of the return value
-      Function1<Row, Long> prefixComputer) {
+      Function1<InternalRow, Long> prefixComputer) {
     this.schema = schema;
     this.rowConverter = new UnsafeRowConverter(schema);
     this.rowComparator = new RowComparator(ordering, schema);
@@ -58,7 +58,7 @@ public UnsafeExternalRowSorter(
     this.prefixComputer = prefixComputer;
   }
 
-  public Iterator<Row> sort(Iterator<Row> inputIterator) throws IOException {
+  public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
     final SparkEnv sparkEnv = SparkEnv.get();
     final TaskContext taskContext = TaskContext.get();
     byte[] rowConversionBuffer = new byte[1024 * 8];
@@ -74,7 +74,7 @@ public Iterator<Row> sort(Iterator<Row> inputIterator) throws IOException {
     );
     try {
       while (inputIterator.hasNext()) {
-        final Row row = inputIterator.next();
+        final InternalRow row = inputIterator.next();
         final int sizeRequirement = rowConverter.getSizeRequirement(row);
         if (sizeRequirement > rowConversionBuffer.length) {
           rowConversionBuffer = new byte[sizeRequirement];
@@ -108,7 +108,7 @@ public boolean hasNext() {
         }
 
         @Override
-        public Row next() {
+        public InternalRow next() {
           try {
             sortedIterator.loadNext();
             if (hasNext()) {
@@ -150,12 +150,12 @@ public Row next() {
 
   private static final class RowComparator extends RecordComparator {
     private final StructType schema;
-    private final Ordering<Row> ordering;
+    private final Ordering<InternalRow> ordering;
     private final int numFields;
     private final UnsafeRow row1 = new UnsafeRow();
     private final UnsafeRow row2 = new UnsafeRow();
 
-    public RowComparator(Ordering<Row> ordering, StructType schema) {
+    public RowComparator(Ordering<InternalRow> ordering, StructType schema) {
       this.schema = schema;
       this.numFields = schema.length();
       this.ordering = ordering;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
index 38d0b6ad25c9a..cfefb13e7721e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.InternalRow
+
 /**
  * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator
  * class for this, but that class is `private[scala]` in 2.10. We need to explicitly fix this to
  * `Row` in order to work around a spurious IntelliJ compiler error.
  */
-private[spark] abstract class AbstractScalaRowIterator extends Iterator[Row]
+private[spark] abstract class AbstractScalaRowIterator extends Iterator[InternalRow]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 8372bd0810234..6d0c97e5e23dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
 
 import scala.util.control.NonFatal
 
-import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.serializer.Serializer
@@ -35,16 +34,6 @@ import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.MutablePair
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 
-object Exchange {
-  /**
-   * Returns true when the ordering expressions are a subset of the key.
-   * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]].
-   */
-  def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
-    desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
-  }
-}
-
 /**
  * :: DeveloperApi ::
  * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
@@ -194,9 +183,6 @@ case class Exchange(
         }
       }
       val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part)
-      if (newOrdering.nonEmpty) {
-        shuffled.setKeyOrdering(keyOrdering)
-      }
       shuffled.setSerializer(serializer)
       shuffled.map(_._2)
 
@@ -317,23 +303,20 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
            child
          }
 
-        val withSort = if (needSort) {
-          // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow
-          // supports the given schema.
-          val supportsUnsafeRowConversion: Boolean = try {
-            new UnsafeRowConverter(withShuffle.schema.map(_.dataType).toArray)
-            true
-          } catch {
-            case NonFatal(e) =>
-              false
-          }
-          if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) {
-            UnsafeExternalSort(rowOrdering, global = false, withShuffle)
-          } else if (sqlContext.conf.externalSortEnabled) {
-            ExternalSort(rowOrdering, global = false, withShuffle)
-          } else {
-            Sort(rowOrdering, global = false, withShuffle)
-          }
+          val withSort = if (needSort) {
+            // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow
+            // supports the given schema.
+            val supportsUnsafeRowConversion: Boolean = try {
+              new UnsafeRowConverter(withShuffle.schema.map(_.dataType).toArray)
+              true
+            } catch {
+              case NonFatal(e) =>
+                false
+            }
+            if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) {
+              UnsafeExternalSort(rowOrdering, global = false, withShuffle)
+            } else if (sqlContext.conf.externalSortEnabled) {
+              ExternalSort(rowOrdering, global = false, withShuffle)
             } else {
               Sort(rowOrdering, global = false, withShuffle)
             }
@@ -364,18 +347,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       case (UnspecifiedDistribution, Seq(), child) =>
         child
       case (UnspecifiedDistribution, rowOrdering, child) =>
-        // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow
-        // supports the given schema.
-        val supportsUnsafeRowConversion: Boolean = try {
-          new UnsafeRowConverter(child.schema.map(_.dataType).toArray)
-          true
-        } catch {
-          case NonFatal(e) =>
-            false
-        }
-        if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) {
-          UnsafeExternalSort(rowOrdering, global = false, child)
-        } else if (sqlContext.conf.externalSortEnabled) {
+        if (sqlContext.conf.externalSortEnabled) {
           ExternalSort(rowOrdering, global = false, child)
         } else {
           Sort(rowOrdering, global = false, child)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 2ae683d601ede..262e14d7a859e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -268,15 +268,15 @@ case class UnsafeExternalSort(
   override def requiredChildDistribution: Seq[Distribution] =
     if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
 
-  protected override def doExecute(): RDD[Row] = attachTree(this, "sort") {
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
     assert (codegenEnabled)
-    def doSort(iterator: Iterator[Row]): Iterator[Row] = {
+    def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
       val ordering = newOrdering(sortOrder, child.output)
       val prefixComparator = new PrefixComparator {
         override def compare(prefix1: Long, prefix2: Long): Int = 0
       }
       // TODO: do real prefix comparsion. For dev/testing purposes, this is a dummy implementation.
-      def prefixComputer(row: Row): Long = 0
+      def prefixComputer(row: InternalRow): Long = 0
       new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator)
     }
     child.execute().mapPartitions(doSort, preservesPartitioning = true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 0699650102a6f..2abe65a71813d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -22,7 +22,6 @@ import java.util.NoSuchElementException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.util.collection.CompactBuffer
@@ -64,24 +63,24 @@ case class SortMergeJoin(
     val rightResults = right.execute().map(_.copy())
 
     leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
-      new Iterator[Row] {
+      new Iterator[InternalRow] {
         // Mutable per row objects.
         private[this] val joinRow = new JoinedRow5
-        private[this] var leftElement: Row = _
-        private[this] var rightElement: Row = _
-        private[this] var leftKey: Row = _
-        private[this] var rightKey: Row = _
-        private[this] var rightMatches: CompactBuffer[Row] = _
+        private[this] var leftElement: InternalRow = _
+        private[this] var rightElement: InternalRow = _
+        private[this] var leftKey: InternalRow = _
+        private[this] var rightKey: InternalRow = _
+        private[this] var rightMatches: CompactBuffer[InternalRow] = _
         private[this] var rightPosition: Int = -1
         private[this] var stop: Boolean = false
-        private[this] var matchKey: Row = _
+        private[this] var matchKey: InternalRow = _
 
         // initialize iterator
         initialize()
 
         override final def hasNext: Boolean = nextMatchingPair()
 
-        override final def next(): Row = {
+        override final def next(): InternalRow = {
           if (hasNext) {
             // we are using the buffered right rows and run down left iterator
             val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
@@ -144,7 +143,7 @@ case class SortMergeJoin(
                fetchLeft()
              }
            }
-          rightMatches = new CompactBuffer[Row]()
+          rightMatches = new CompactBuffer[InternalRow]()
           if (stop) {
             stop = false
             // iterate the right side to buffer all rows that matches
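
Note (not part of the patch): the sketch below mirrors the doSort helper in basicOperators.scala above and shows how the refactored sorter is driven once everything operates on InternalRow instead of Row. It is a minimal sketch under stated assumptions: the object and method names (RowSortSketch, sortPartition) and their parameters are hypothetical, the PrefixComparator import location is assumed from this PR branch and may differ, the prefix comparator and prefixComputer are the same dummy placeholders the patch itself uses, and the code sits in org.apache.spark.sql.execution because UnsafeExternalRowSorter is package-private in this patch. The Ordering[InternalRow] would typically come from newOrdering(sortOrder, child.output), as in the UnsafeExternalSort operator above.

    package org.apache.spark.sql.execution

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types.StructType
    import org.apache.spark.util.collection.unsafe.sort.PrefixComparator

    // Hypothetical helper (not in the patch): sort one partition of InternalRows
    // with the unsafe row sorter introduced by this PR branch.
    object RowSortSketch {
      def sortPartition(
          schema: StructType,
          ordering: Ordering[InternalRow],
          input: Iterator[InternalRow]): Iterator[InternalRow] = {
        // Dummy prefix handling, exactly as in the patch's UnsafeExternalSort operator.
        val prefixComparator = new PrefixComparator {
          override def compare(prefix1: Long, prefix2: Long): Int = 0
        }
        def prefixComputer(row: InternalRow): Long = 0
        new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
          .sort(input)
      }
    }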