Update for InternalRow refactoring
JoshRosen committed Jul 6, 2015
1 parent 269cf86 commit d468a88
Showing 5 changed files with 39 additions and 66 deletions.
@@ -27,7 +27,7 @@
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.sql.AbstractScalaRowIterator;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
import org.apache.spark.sql.types.StructType;
@@ -43,22 +43,22 @@ final class UnsafeExternalRowSorter {
private final UnsafeRowConverter rowConverter;
private final RowComparator rowComparator;
private final PrefixComparator prefixComparator;
private final Function1<Row, Long> prefixComputer;
private final Function1<InternalRow, Long> prefixComputer;

public UnsafeExternalRowSorter(
StructType schema,
Ordering<Row> ordering,
Ordering<InternalRow> ordering,
PrefixComparator prefixComparator,
// TODO: if possible, avoid this boxing of the return value
Function1<Row, Long> prefixComputer) {
Function1<InternalRow, Long> prefixComputer) {
this.schema = schema;
this.rowConverter = new UnsafeRowConverter(schema);
this.rowComparator = new RowComparator(ordering, schema);
this.prefixComparator = prefixComparator;
this.prefixComputer = prefixComputer;
}

public Iterator<Row> sort(Iterator<Row> inputIterator) throws IOException {
public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
final SparkEnv sparkEnv = SparkEnv.get();
final TaskContext taskContext = TaskContext.get();
byte[] rowConversionBuffer = new byte[1024 * 8];
@@ -74,7 +74,7 @@ public Iterator<Row> sort(Iterator<Row> inputIterator) throws IOException {
);
try {
while (inputIterator.hasNext()) {
final Row row = inputIterator.next();
final InternalRow row = inputIterator.next();
final int sizeRequirement = rowConverter.getSizeRequirement(row);
if (sizeRequirement > rowConversionBuffer.length) {
rowConversionBuffer = new byte[sizeRequirement];
@@ -108,7 +108,7 @@ public boolean hasNext() {
}

@Override
public Row next() {
public InternalRow next() {
try {
sortedIterator.loadNext();
if (hasNext()) {
@@ -150,12 +150,12 @@ public Row next() {

private static final class RowComparator extends RecordComparator {
private final StructType schema;
private final Ordering<Row> ordering;
private final Ordering<InternalRow> ordering;
private final int numFields;
private final UnsafeRow row1 = new UnsafeRow();
private final UnsafeRow row2 = new UnsafeRow();

public RowComparator(Ordering<Row> ordering, StructType schema) {
public RowComparator(Ordering<InternalRow> ordering, StructType schema) {
this.schema = schema;
this.numFields = schema.length();
this.ordering = ordering;
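A minimal Scala sketch of a call site under the new InternalRow-only signatures, modeled on UnsafeExternalSort.doSort further down in this diff. The PrefixComparator import path is an assumption, and such a helper would have to live in org.apache.spark.sql.execution because the sorter class is package-private:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
// Assumed import path for the Tungsten sort utilities.
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator

// Sketch: sort one partition of InternalRows with UnsafeExternalRowSorter,
// using the same dummy prefix handling as the current UnsafeExternalSort operator.
def sortPartition(
    schema: StructType,
    ordering: Ordering[InternalRow],
    rows: Iterator[InternalRow]): Iterator[InternalRow] = {
  val prefixComparator = new PrefixComparator {
    override def compare(prefix1: Long, prefix2: Long): Int = 0
  }
  def prefixComputer(row: InternalRow): Long = 0L
  new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(rows)
}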
@@ -17,9 +17,11 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.InternalRow

/**
* Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator
* class for this, but that class is `private[scala]` in 2.10. We need to explicitly fix this to
* `Row` in order to work around a spurious IntelliJ compiler error.
*/
private[spark] abstract class AbstractScalaRowIterator extends Iterator[Row]
private[spark] abstract class AbstractScalaRowIterator extends Iterator[InternalRow]
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution

import scala.util.control.NonFatal

import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.serializer.Serializer
@@ -35,16 +34,6 @@ import org.apache.spark.sql.types.DataType
import org.apache.spark.util.MutablePair
import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}

object Exchange {
/**
* Returns true when the ordering expressions are a subset of the key.
* if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]].
*/
def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
}
}

/**
* :: DeveloperApi ::
* Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
@@ -194,9 +183,6 @@ case class Exchange(
}
}
val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part)
if (newOrdering.nonEmpty) {
shuffled.setKeyOrdering(keyOrdering)
}
shuffled.setSerializer(serializer)
shuffled.map(_._2)
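With setKeyOrdering gone, the exchange reduces to hash-partitioning followed by dropping the keys; any required ordering is left to the sort operators that EnsureRequirements inserts below. A minimal sketch of that pattern with placeholder names (the real operator also sets a serializer and may copy or reuse mutable rows):

import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.sql.catalyst.InternalRow

// Sketch: rows already paired with their partitioning key are shuffled by hash,
// and the keys are discarded afterwards.
def hashExchange(
    keyedRows: RDD[(InternalRow, InternalRow)],
    numPartitions: Int): RDD[InternalRow] = {
  val part = new HashPartitioner(numPartitions)
  val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](keyedRows, part)
  shuffled.map(_._2)
}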

@@ -317,23 +303,20 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
child
}

val withSort = if (needSort) {
// TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow
// supports the given schema.
val supportsUnsafeRowConversion: Boolean = try {
new UnsafeRowConverter(withShuffle.schema.map(_.dataType).toArray)
true
} catch {
case NonFatal(e) =>
false
}
if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) {
UnsafeExternalSort(rowOrdering, global = false, withShuffle)
} else if (sqlContext.conf.externalSortEnabled) {
ExternalSort(rowOrdering, global = false, withShuffle)
} else {
Sort(rowOrdering, global = false, withShuffle)
}
val withSort = if (needSort) {
// TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow
// supports the given schema.
val supportsUnsafeRowConversion: Boolean = try {
new UnsafeRowConverter(withShuffle.schema.map(_.dataType).toArray)
true
} catch {
case NonFatal(e) =>
false
}
if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) {
UnsafeExternalSort(rowOrdering, global = false, withShuffle)
} else if (sqlContext.conf.externalSortEnabled) {
ExternalSort(rowOrdering, global = false, withShuffle)
} else {
Sort(rowOrdering, global = false, withShuffle)
}
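The try/catch "hack" above treats a successful UnsafeRowConverter construction as evidence that the schema is supported, and any non-fatal exception as "unsupported". The same probe as a standalone sketch (the helper name is illustrative; the Array[DataType] constructor is the one used above):

import scala.util.control.NonFatal
import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter
import org.apache.spark.sql.types.DataType

// Sketch: probe whether every field type in a schema can be converted to UnsafeRow.
// Constructing the converter throws for unsupported types, so it doubles as the check.
def supportsUnsafeRowFormat(fieldTypes: Array[DataType]): Boolean = {
  try {
    new UnsafeRowConverter(fieldTypes)
    true
  } catch {
    case NonFatal(_) => false
  }
}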
@@ -364,18 +347,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
case (UnspecifiedDistribution, Seq(), child) =>
child
case (UnspecifiedDistribution, rowOrdering, child) =>
// TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow
// supports the given schema.
val supportsUnsafeRowConversion: Boolean = try {
new UnsafeRowConverter(child.schema.map(_.dataType).toArray)
true
} catch {
case NonFatal(e) =>
false
}
if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) {
UnsafeExternalSort(rowOrdering, global = false, child)
} else if (sqlContext.conf.externalSortEnabled) {
if (sqlContext.conf.externalSortEnabled) {
ExternalSort(rowOrdering, global = false, child)
} else {
Sort(rowOrdering, global = false, child)
@@ -268,15 +268,15 @@ case class UnsafeExternalSort(
override def requiredChildDistribution: Seq[Distribution] =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

protected override def doExecute(): RDD[Row] = attachTree(this, "sort") {
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
assert (codegenEnabled)
def doSort(iterator: Iterator[Row]): Iterator[Row] = {
def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
val ordering = newOrdering(sortOrder, child.output)
val prefixComparator = new PrefixComparator {
override def compare(prefix1: Long, prefix2: Long): Int = 0
}
// TODO: do real prefix comparison. For dev/testing purposes, this is a dummy implementation.
def prefixComputer(row: Row): Long = 0
def prefixComputer(row: InternalRow): Long = 0
new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator)
}
child.execute().mapPartitions(doSort, preservesPartitioning = true)
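The prefix computer above returns 0 for every row, so the prefix comparison never discriminates and every comparison falls through to the full row comparator. A purely illustrative sketch of a non-dummy prefix path for a schema whose leading sort column is a non-nullable LongType (assumes InternalRow exposes getLong and that the import path below is correct):

import org.apache.spark.sql.catalyst.InternalRow
// Assumed import path for the Tungsten sort utilities.
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator

// Illustrative: use the leading long-typed sort column's value as the 8-byte prefix.
def longColumnPrefixComputer(ordinal: Int)(row: InternalRow): Long = row.getLong(ordinal)

// Signed comparison of the two prefixes.
val longPrefixComparator = new PrefixComparator {
  override def compare(prefix1: Long, prefix2: Long): Int = java.lang.Long.compare(prefix1, prefix2)
}
// Usage sketch: new UnsafeExternalRowSorter(schema, ordering, longPrefixComparator, longColumnPrefixComputer(0))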
@@ -22,7 +22,6 @@ import java.util.NoSuchElementException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.collection.CompactBuffer
@@ -64,24 +63,24 @@ case class SortMergeJoin(
val rightResults = right.execute().map(_.copy())

leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
new Iterator[Row] {
new Iterator[InternalRow] {
// Mutable per row objects.
private[this] val joinRow = new JoinedRow5
private[this] var leftElement: Row = _
private[this] var rightElement: Row = _
private[this] var leftKey: Row = _
private[this] var rightKey: Row = _
private[this] var rightMatches: CompactBuffer[Row] = _
private[this] var leftElement: InternalRow = _
private[this] var rightElement: InternalRow = _
private[this] var leftKey: InternalRow = _
private[this] var rightKey: InternalRow = _
private[this] var rightMatches: CompactBuffer[InternalRow] = _
private[this] var rightPosition: Int = -1
private[this] var stop: Boolean = false
private[this] var matchKey: Row = _
private[this] var matchKey: InternalRow = _

// initialize iterator
initialize()

override final def hasNext: Boolean = nextMatchingPair()

override final def next(): Row = {
override final def next(): InternalRow = {
if (hasNext) {
// we are using the buffered right rows and run down left iterator
val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
@@ -144,7 +143,7 @@ case class SortMergeJoin(
fetchLeft()
}
}
rightMatches = new CompactBuffer[Row]()
rightMatches = new CompactBuffer[InternalRow]()
if (stop) {
stop = false
// iterate the right side to buffer all rows that match
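The hunk is cut off here, but the matching loop it belongs to follows the standard sort-merge pattern: advance whichever side has the smaller key, and on a key match buffer the run of equal-keyed right rows (the CompactBuffer above) and replay it against each left row with that key. A generic, Spark-free sketch of that loop (output is buffered eagerly for brevity; the real operator streams results):

import scala.collection.mutable.ArrayBuffer

// Sketch of an inner sort-merge join over two key-sorted iterators.
def sortMergeJoin[K: Ordering, L, R](
    left: Iterator[(K, L)],
    right: Iterator[(K, R)]): Iterator[(L, R)] = {
  val ord = implicitly[Ordering[K]]
  val l = left.buffered
  val r = right.buffered
  val out = ArrayBuffer.empty[(L, R)]
  while (l.hasNext && r.hasNext) {
    val cmp = ord.compare(l.head._1, r.head._1)
    if (cmp < 0) { l.next() }        // left key is smaller: advance left
    else if (cmp > 0) { r.next() }   // right key is smaller: advance right
    else {
      val key = l.head._1
      // Buffer the run of right rows sharing this key (CompactBuffer in the operator).
      val rightMatches = ArrayBuffer.empty[R]
      while (r.hasNext && ord.equiv(r.head._1, key)) rightMatches += r.next()._2
      // Replay the buffered right rows against every left row with the same key.
      while (l.hasNext && ord.equiv(l.head._1, key)) {
        val leftValue = l.next()._2
        rightMatches.foreach(rv => out += ((leftValue, rv)))
      }
    }
  }
  out.iterator
}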
