Fix tracking of indices within a partition in SpillReader, and add test
mateiz committed Jul 30, 2014
1 parent 03e1006 commit a34b352
Showing 2 changed files with 59 additions and 13 deletions.
core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala:

```diff
@@ -480,11 +480,16 @@ private[spark] class ExternalSorter[K, V, C](
     val fileStream = new FileInputStream(spill.file)
     val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
 
-    // Track which partition and which batch stream we're in
+    // Track which partition and which batch stream we're in. These will be the indices of
+    // the next element we will read. We'll also store the last partition read so that
+    // readNextPartition() can figure out what partition that was from.
     var partitionId = 0
-    var indexInPartition = -1L // Just to make sure we start at index 0
+    var indexInPartition = 0L
     var batchStreamsRead = 0
     var indexInBatch = 0
+    var lastPartitionId = 0
+
+    skipToNextPartition()
 
     // An intermediate stream that reads from exactly one batch
     // This guards against pre-fetching and other arbitrary behavior of higher level streams
```
```diff
@@ -500,6 +505,18 @@ private[spark] class ExternalSorter[K, V, C](
       ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1))
     }
 
+    /**
+     * Update partitionId if we have reached the end of our current partition, possibly skipping
+     * empty partitions on the way.
+     */
+    private def skipToNextPartition() {
+      while (partitionId < numPartitions &&
+          indexInPartition == spill.elementsPerPartition(partitionId)) {
+        partitionId += 1
+        indexInPartition = 0L
+      }
+    }
+
     /**
      * Return the next (K, C) pair from the deserialization stream and update partitionId,
      * indexInPartition, indexInBatch and such to match its location.
```
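Taken together, these first two hunks change the reader's invariant: partitionId and indexInPartition now always point at the next element to be read, which is why the constructor must skip over leading empty partitions before anything is read. Below is a minimal standalone model of the fixed bookkeeping (a sketch, not the real SpillReader; the IndexTrackingSketch name and the plain array standing in for the spill's per-partition counts are illustrative only):

```scala
// Model of the new initialization: partitions 0 and 2 are empty, so the
// upfront skipToNextPartition() must land the cursor on partition 1, the
// first partition that actually holds data.
object IndexTrackingSketch {
  val elementsPerPartition = Array(0L, 2L, 0L, 1L)
  val numPartitions = elementsPerPartition.length

  var partitionId = 0       // partition of the next element to read
  var indexInPartition = 0L // index of the next element within that partition

  def skipToNextPartition(): Unit = {
    while (partitionId < numPartitions &&
        indexInPartition == elementsPerPartition(partitionId)) {
      partitionId += 1
      indexInPartition = 0L
    }
  }

  def main(args: Array[String]): Unit = {
    skipToNextPartition()
    println(s"first read comes from partition $partitionId") // partition 1
  }
}
```

The old initialization (indexInPartition = -1L, no upfront skip) had no way to represent "positioned before a run of empty partitions", which is exactly the situation the new test below constructs.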
```diff
@@ -513,6 +530,7 @@ private[spark] class ExternalSorter[K, V, C](
       }
       val k = deserStream.readObject().asInstanceOf[K]
       val c = deserStream.readObject().asInstanceOf[C]
+      lastPartitionId = partitionId
       // Start reading the next batch if we're done with this one
       indexInBatch += 1
       if (indexInBatch == serializerBatchSize) {
```
```diff
@@ -521,16 +539,11 @@ private[spark] class ExternalSorter[K, V, C](
         deserStream = serInstance.deserializeStream(compressedStream)
         indexInBatch = 0
       }
-      // Update the partition location of the element we're reading, possibly skipping zero-length
-      // partitions until we get to the next non-empty one or to EOF.
+      // Update the partition location of the element we're reading
      indexInPartition += 1
-      while (indexInPartition == spill.elementsPerPartition(partitionId)) {
-        partitionId += 1
-        indexInPartition = 0
-      }
-      if (partitionId == numPartitions - 1 &&
-          indexInPartition == spill.elementsPerPartition(partitionId) - 1) {
-        // This is the last element, remember that we're done
+      skipToNextPartition()
+      // If we've finished reading the last partition, remember that we're done
+      if (partitionId == numPartitions) {
         finished = true
         deserStream.close()
       }
```
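The rewritten epilogue above also fixes end-of-stream detection: the removed condition only fired on the last element of partition numPartitions - 1, so a spill whose trailing partitions were empty was never marked finished. Extending the earlier sketch with the new epilogue (same caveats apply; this is a model, not the real reader):

```scala
// Drain a spill whose trailing partitions are empty. The new EOF check,
// partitionId == numPartitions, fires once the skip runs off the end of the
// partition table, even though partitions 3 and 4 hold no data.
object EpilogueSketch {
  val elementsPerPartition = Array(2L, 0L, 1L, 0L, 0L)
  val numPartitions = elementsPerPartition.length

  var partitionId = 0
  var indexInPartition = 0L
  var finished = false

  def skipToNextPartition(): Unit = {
    while (partitionId < numPartitions &&
        indexInPartition == elementsPerPartition(partitionId)) {
      partitionId += 1
      indexInPartition = 0L
    }
  }

  def main(args: Array[String]): Unit = {
    skipToNextPartition()
    while (!finished) {
      println(s"read element $indexInPartition of partition $partitionId")
      indexInPartition += 1
      skipToNextPartition()
      if (partitionId == numPartitions) finished = true
    }
    // Prints two reads from partition 0 and one from partition 2, then stops
    // cleanly despite the empty partitions 1, 3 and 4.
  }
}
```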
```diff
@@ -550,10 +563,10 @@ private[spark] class ExternalSorter[K, V, C](
           return false
         }
       }
-      assert(partitionId >= myPartition)
+      assert(lastPartitionId >= myPartition)
       // Check that we're still in the right partition; note that readNextItem will have returned
       // null at EOF above so we would've returned false there
-      partitionId == myPartition
+      lastPartitionId == myPartition
     }
 
     override def next(): Product2[K, C] = {
```
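This last hunk is the reason lastPartitionId exists: by the time readNextItem() hands back an element, the epilogue has already advanced partitionId to the partition of the next element, possibly several partitions ahead when empty ones intervene, so testing partitionId in hasNext would reject the last element of every partition. Here is a simplified model of readNextPartition() (a sketch only; the real reader buffers (K, C) pairs from the deserialization stream, and the names below are illustrative):

```scala
// One iterator per partition, handed out in partition order, all sharing a
// single buffered element: the same shape as SpillReader.readNextPartition.
object ReadNextPartitionSketch {
  // (partition, value) pairs in the order they sit in a spill file
  private val data = Iterator((1, "a"), (1, "b"), (3, "c"))
  private var lastPartitionId = -1           // partition the element was read FROM
  private var nextItem: Option[String] = None

  private def readNextItem(): Option[String] =
    if (data.hasNext) {
      val (p, v) = data.next()
      lastPartitionId = p
      Some(v)
    } else None

  def readNextPartition(myPartition: Int): Iterator[String] = new Iterator[String] {
    def hasNext: Boolean = {
      if (nextItem.isEmpty) nextItem = readNextItem()
      // Comparing against the partition we read FROM keeps the last element
      // of each partition; an already-advanced cursor would drop it.
      nextItem.isDefined && lastPartitionId == myPartition
    }
    def next(): String = { val v = nextItem.get; nextItem = None; v }
  }

  def main(args: Array[String]): Unit = {
    for (p <- 0 until 4) println(s"partition $p -> ${readNextPartition(p).toList}")
    // partition 0 -> List(), 1 -> List(a, b), 2 -> List(), 3 -> List(c)
  }
}
```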
core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala:

```diff
@@ -38,21 +38,25 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
     val sorter = new ExternalSorter[Int, Int, Int](
       Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
     assert(sorter.iterator.toSeq === Seq())
+    sorter.stop()
 
     // Only aggregator
     val sorter2 = new ExternalSorter[Int, Int, Int](
       Some(agg), Some(new HashPartitioner(3)), None, None)
     assert(sorter2.iterator.toSeq === Seq())
+    sorter2.stop()
 
     // Only ordering
     val sorter3 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
     assert(sorter3.iterator.toSeq === Seq())
+    sorter3.stop()
 
     // Neither aggregator nor ordering
     val sorter4 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), None, None)
     assert(sorter4.iterator.toSeq === Seq())
+    sorter4.stop()
   }
 
   test("few elements per partition") {
```
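The stop() calls added throughout the suite release each sorter's resources, which in this version of ExternalSorter includes deleting its intermediate spill files. The tests call stop() unconditionally at the end of each block; a loan-pattern variant would survive failing assertions too (a hedged sketch: withSorter is a hypothetical helper, not part of the suite, and it assumes the suite's scope where the private[spark] ExternalSorter and HashPartitioner are visible):

```scala
// Hypothetical helper (not in the suite): always stop the sorter, so spill
// files are deleted even when an assertion inside the body throws.
def withSorter[K, V, C](sorter: ExternalSorter[K, V, C])(body: ExternalSorter[K, V, C] => Unit): Unit = {
  try {
    body(sorter)
  } finally {
    sorter.stop()
  }
}

// Usage, mirroring the last case above:
withSorter(new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)) { s =>
  assert(s.iterator.toSeq === Seq())
}
```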
```diff
@@ -73,24 +77,53 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
       Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
     sorter.write(elements.iterator)
     assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter.stop()
 
     // Only aggregator
     val sorter2 = new ExternalSorter[Int, Int, Int](
       Some(agg), Some(new HashPartitioner(7)), None, None)
     sorter2.write(elements.iterator)
     assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter2.stop()
 
     // Only ordering
     val sorter3 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), Some(ord), None)
     sorter3.write(elements.iterator)
     assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter3.stop()
 
     // Neither aggregator nor ordering
     val sorter4 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), None, None)
     sorter4.write(elements.iterator)
     assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter4.stop()
   }
 
+  test("empty partitions with spilling") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+    val elements = Iterator((1, 1), (5, 5)) ++ (0 until 50000).iterator.map(x => (2, 2))
+
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(7)), None, None)
+    sorter.write(elements)
+    assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
+    val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
+    assert(iter.next() === (0, Nil))
+    assert(iter.next() === (1, List((1, 1))))
+    assert(iter.next() === (2, (0 until 50000).map(x => (2, 2)).toList))
+    assert(iter.next() === (3, Nil))
+    assert(iter.next() === (4, Nil))
+    assert(iter.next() === (5, List((5, 5))))
+    assert(iter.next() === (6, Nil))
+    sorter.stop()
+  }
+
   test("spilling in local cluster") {
```
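The expected contents in the new test follow from the partitioner: HashPartitioner(7) places a key in partition key.hashCode % 7 (adjusted to be non-negative), and an Int hashes to itself, so keys 1, 5 and 2 land in partitions 1, 5 and 2 while partitions 0, 3, 4 and 6 stay empty; the tiny spark.shuffle.memoryFraction forces the 50000 (2, 2) pairs to spill along the way. A quick check outside Spark (nonNegativeMod here is a local stand-in for Spark's helper of the same name):

```scala
// Verify where HashPartitioner(7) puts each key the test writes.
object PartitionPlacementSketch {
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val raw = x % mod
    if (raw < 0) raw + mod else raw
  }

  def main(args: Array[String]): Unit = {
    for (k <- Seq(1, 5, 2)) {
      println(s"key $k -> partition ${nonNegativeMod(k.hashCode, 7)}")
    }
    // key 1 -> partition 1, key 5 -> partition 5, key 2 -> partition 2
  }
}
```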
