Commit failing test demonstrating bug in handling objects in spills
JoshRosen committed Jul 6, 2015
1 parent 41b8881 commit 7f875f9
Showing 5 changed files with 116 additions and 59 deletions.
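
The essence of the new failing test is the call sequence sketched below: convert external rows to Catalyst rows, insert one row, force a spill, insert the rest, then read the sorted rows back. This is a condensed sketch using the Spark-internal APIs exactly as they appear in this commit (insertRow, spill, and sort are the @VisibleForTesting hooks added here); in the real test this code runs inside mapPartitions so that a TaskContext and SparkEnv are available, and that plumbing is omitted here.

// Condensed sketch of the failing scenario; assumed to live in package
// org.apache.spark.sql.execution, like the test added at the end of this diff.
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, SortOrder}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator

// Schema with an object (nested struct) column, the case the spill path mishandles.
val schema = StructType(
  StructField("a", StringType, nullable = false) ::
  StructField("b", StructType(StructField("b", IntegerType, nullable = false) :: Nil)) :: Nil)
val sortOrder = SortOrder(BoundReference(0, StringType, nullable = false), Ascending)
val sorter = new UnsafeExternalRowSorter(
  schema,
  GenerateOrdering.generate(Seq(sortOrder), schema.toAttributes),
  new PrefixComparator { override def compare(prefix1: Long, prefix2: Long): Int = 0 },
  x => 0L)  // dummy prefix computer, as in the test
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)

sorter.insertRow(toCatalyst(Row("Hello", Row(1))).asInstanceOf[InternalRow])
sorter.spill()  // push the buffered row (and its pooled objects) out to disk
sorter.insertRow(toCatalyst(Row("World", Row(2))).asInstanceOf[InternalRow])
val sorted: Iterator[InternalRow] = sorter.sort()  // reading back across the spill exposes the bug
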
@@ -118,7 +118,7 @@ private void initializeForWriting() throws IOException {
* Sort and spill the current records in response to memory pressure.
*/
@VisibleForTesting
void spill() throws IOException {
public void spill() throws IOException {
logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
Thread.currentThread().getId(),
Utils.bytesToString(getMemoryUsage()),
@@ -38,7 +38,6 @@
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
@@ -56,8 +55,6 @@ public class UnsafeExternalSorterSuite {

final TaskMemoryManager memoryManager =
new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
// Compute key prefixes based on the records' partition ids
final HashPartitioner hashPartitioner = new HashPartitioner(4);
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
final PrefixComparator prefixComparator = new PrefixComparator() {
@Override
@@ -138,11 +135,8 @@ private static void insertNumber(UnsafeExternalSorter sorter, int value) throws
sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
}

/**
* Tests the type of sorting that's used in the non-combiner path of sort-based shuffle.
*/
@Test
public void testSortingOnlyByPartitionId() throws Exception {
public void testSortingOnlyByPrefix() throws Exception {

final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
memoryManager,
@@ -24,12 +24,15 @@
import scala.collection.Iterator;
import scala.math.Ordering;

import com.google.common.annotations.VisibleForTesting;

import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.sql.AbstractScalaRowIterator;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
import org.apache.spark.sql.catalyst.util.ObjectPool;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
@@ -41,61 +44,70 @@ final class UnsafeExternalRowSorter {

private final StructType schema;
private final UnsafeRowConverter rowConverter;
private final RowComparator rowComparator;
private final PrefixComparator prefixComparator;
private final Function1<InternalRow, Long> prefixComputer;
private final ObjectPool objPool = new ObjectPool(128);
private final UnsafeExternalSorter sorter;
private byte[] rowConversionBuffer = new byte[1024 * 8];

public UnsafeExternalRowSorter(
StructType schema,
Ordering<InternalRow> ordering,
PrefixComparator prefixComparator,
// TODO: if possible, avoid this boxing of the return value
Function1<InternalRow, Long> prefixComputer) {
Function1<InternalRow, Long> prefixComputer) throws IOException {
this.schema = schema;
this.rowConverter = new UnsafeRowConverter(schema);
this.rowComparator = new RowComparator(ordering, schema);
this.prefixComparator = prefixComparator;
this.prefixComputer = prefixComputer;
}

public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
final SparkEnv sparkEnv = SparkEnv.get();
final TaskContext taskContext = TaskContext.get();
byte[] rowConversionBuffer = new byte[1024 * 8];
final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
sorter = new UnsafeExternalSorter(
taskContext.taskMemoryManager(),
sparkEnv.shuffleMemoryManager(),
sparkEnv.blockManager(),
taskContext,
rowComparator,
new RowComparator(ordering, schema.length(), objPool),
prefixComparator,
4096,
sparkEnv.conf()
);
}

@VisibleForTesting
void insertRow(InternalRow row) throws IOException {
final int sizeRequirement = rowConverter.getSizeRequirement(row);
if (sizeRequirement > rowConversionBuffer.length) {
rowConversionBuffer = new byte[sizeRequirement];
} else {
// Zero out the buffer that's used to hold the current row. This is necessary in order
// to ensure that rows hash properly, since garbage data from the previous row could
// otherwise end up as padding in this row. As a performance optimization, we only zero
// out the portion of the buffer that we'll actually write to.
Arrays.fill(rowConversionBuffer, 0, sizeRequirement, (byte) 0);
}
final int bytesWritten = rowConverter.writeRow(
row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, objPool);
assert (bytesWritten == sizeRequirement);
final long prefix = prefixComputer.apply(row);
sorter.insertRecord(
rowConversionBuffer,
PlatformDependent.BYTE_ARRAY_OFFSET,
sizeRequirement,
prefix
);
}

@VisibleForTesting
void spill() throws IOException {
sorter.spill();
}

private void cleanupResources() {
sorter.freeMemory();
}

@VisibleForTesting
Iterator<InternalRow> sort() throws IOException {
try {
while (inputIterator.hasNext()) {
final InternalRow row = inputIterator.next();
final int sizeRequirement = rowConverter.getSizeRequirement(row);
if (sizeRequirement > rowConversionBuffer.length) {
rowConversionBuffer = new byte[sizeRequirement];
} else {
// Zero out the buffer that's used to hold the current row. This is necessary in order
// to ensure that rows hash properly, since garbage data from the previous row could
// otherwise end up as padding in this row. As a performance optimization, we only zero
// out the portion of the buffer that we'll actually write to.
Arrays.fill(rowConversionBuffer, 0, sizeRequirement, (byte) 0);
}
final int bytesWritten =
rowConverter.writeRow(row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET);
assert (bytesWritten == sizeRequirement);
final long prefix = prefixComputer.apply(row);
sorter.insertRecord(
rowConversionBuffer,
PlatformDependent.BYTE_ARRAY_OFFSET,
sizeRequirement,
prefix
);
}
final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
return new AbstractScalaRowIterator() {

Expand All @@ -113,7 +125,7 @@ public InternalRow next() {
sortedIterator.loadNext();
if (hasNext()) {
row.pointTo(
sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, schema);
sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, objPool);
return row;
} else {
final byte[] rowDataCopy = new byte[sortedIterator.getRecordLength()];
@@ -125,14 +137,12 @@ public InternalRow next() {
sortedIterator.getRecordLength()
);
row.backingArray = rowDataCopy;
row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, schema);
row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, objPool);
sorter.freeMemory();
return row;
}
} catch (IOException e) {
// TODO: we need to ensure that files are cleaned properly after an exception,
// so we need better cleanup methods than freeMemory().
sorter.freeMemory();
cleanupResources();
// Scala iterators don't declare any checked exceptions, so we need to use this hack
// to re-throw the exception:
PlatformDependent.throwException(e);
@@ -141,30 +151,36 @@ public InternalRow next() {
};
};
} catch (IOException e) {
// TODO: we need to ensure that files are cleaned properly after an exception,
// so we need better cleanup methods than freeMemory().
sorter.freeMemory();
cleanupResources();
throw e;
}
}


public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
while (inputIterator.hasNext()) {
insertRow(inputIterator.next());
}
return sort();
}

private static final class RowComparator extends RecordComparator {
private final StructType schema;
private final Ordering<InternalRow> ordering;
private final int numFields;
private final ObjectPool objPool;
private final UnsafeRow row1 = new UnsafeRow();
private final UnsafeRow row2 = new UnsafeRow();

public RowComparator(Ordering<InternalRow> ordering, StructType schema) {
this.schema = schema;
this.numFields = schema.length();
public RowComparator(Ordering<InternalRow> ordering, int numFields, ObjectPool objPool) {
this.numFields = numFields;
this.ordering = ordering;
this.objPool = objPool;
}

@Override
public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
row1.pointTo(baseObj1, baseOff1, numFields, schema);
row2.pointTo(baseObj2, baseOff2, numFields, schema);
row1.pointTo(baseObj1, baseOff1, numFields, objPool);
row2.pointTo(baseObj2, baseOff2, numFields, objPool);
return ordering.compare(row1, row2);
}
}
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution

import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
import org.apache.spark.{SparkEnv, HashPartitioner}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -275,7 +274,7 @@ case class UnsafeExternalSort(
val prefixComparator = new PrefixComparator {
override def compare(prefix1: Long, prefix2: Long): Int = 0
}
// TODO: do real prefix comparsion. For dev/testing purposes, this is a dummy implementation.
// TODO: do real prefix comparison. For dev/testing purposes, this is a dummy implementation.
def prefixComputer(row: InternalRow): Long = 0
new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator)
}
@@ -17,9 +17,14 @@

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, AttributeReference, SortOrder}
import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType}
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.{Row, SQLConf}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.TestSQLContext

@@ -54,4 +59,47 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
input.sortBy(t => (t._2, t._1)))
}

test("sorting with object columns") {
// TODO: larger input data
val input = Seq(
Row("Hello", Row(1)),
Row("World", Row(2))
)

val schema = StructType(
StructField("a", StringType, nullable = false) ::
StructField("b", StructType(StructField("b", IntegerType, nullable = false) :: Nil)) ::
Nil
)

// Hack so that we don't need to pass in / mock TaskContext, SparkEnv, etc. Ultimately it would
// be better to not use this hack, but due to time constraints I have deferred this for
// followup PRs.
val sortResult = TestSQLContext.sparkContext.parallelize(input, 1).mapPartitions { iter =>
val rows = iter.toSeq
val sortOrder = SortOrder(BoundReference(0, StringType, nullable = false), Ascending)

val sorter = new UnsafeExternalRowSorter(
schema,
GenerateOrdering.generate(Seq(sortOrder), schema.toAttributes),
new PrefixComparator {
override def compare(prefix1: Long, prefix2: Long): Int = 0
},
x => 0L
)

val toCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)

sorter.insertRow(toCatalyst(input.head).asInstanceOf[InternalRow])
sorter.spill()
input.tail.foreach { row =>
sorter.insertRow(toCatalyst(row).asInstanceOf[InternalRow])
}
val sortedRowsIterator = sorter.sort()
sortedRowsIterator.map(CatalystTypeConverters.convertToScala(_, schema).asInstanceOf[Row])
}.collect()

assert(input.sortBy(t => t.getString(0)) === sortResult)
}
}
