diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
index 04ac811ac7966..fe092683d5400 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.shuffle.unsafe
 
-import java.io.{FileOutputStream, OutputStream}
 import java.nio.ByteBuffer
 import java.util
 
@@ -29,7 +28,7 @@ import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.storage.{BlockObjectWriter, ShuffleBlockId}
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
 import org.apache.spark.unsafe.sort.UnsafeSorter
@@ -104,7 +103,14 @@ private[spark] class UnsafeShuffleWriter[K, V](
 
   private[this] val blockManager = SparkEnv.get.blockManager
 
-  private def sortRecords(records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = {
+  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+  private[this] val fileBufferSize =
+    SparkEnv.get.conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
+
+  private[this] val serializer = Serializer.getSerializer(dep.serializer).newInstance()
+
+  private def sortRecords(
+      records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = {
     val sorter = new UnsafeSorter(
       context.taskMemoryManager(),
       DummyRecordComparator,
@@ -112,7 +118,6 @@ private[spark] class UnsafeShuffleWriter[K, V](
       PartitionerPrefixComparator,
       4096 // initial size
     )
-    val serializer = Serializer.getSerializer(dep.serializer).newInstance()
     val PAGE_SIZE = 1024 * 1024 * 1
 
     var currentPage: MemoryBlock = null
@@ -178,32 +183,31 @@ private[spark] class UnsafeShuffleWriter[K, V](
     sorter.getSortedIterator
   }
 
-  private def writeSortedRecordsToFile(sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = {
+  private def writeSortedRecordsToFile(
+      sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = {
     val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID)
     val partitionLengths = new Array[Long](partitioner.numPartitions)
 
     var currentPartition = -1
-    var prevPartitionLength: Long = 0
-    var out: OutputStream = null
+    var writer: BlockObjectWriter = null
 
     // TODO: don't close and re-open file handles so often; this could be inefficient
 
     def closePartition(): Unit = {
-      out.flush()
-      out.close()
-      partitionLengths(currentPartition) = outputFile.length() - prevPartitionLength
+      writer.commitAndClose()
+      partitionLengths(currentPartition) = writer.fileSegment().length
    }
 
     def switchToPartition(newPartition: Int): Unit = {
-      assert (newPartition > currentPartition, s"new partition $newPartition should be >= $currentPartition")
+      assert (newPartition > currentPartition,
+        s"new partition $newPartition should be >= $currentPartition")
       if (currentPartition != -1) {
         closePartition()
-        prevPartitionLength = partitionLengths(currentPartition)
       }
-      println(s"Before switching to partition $newPartition, partition lengths are " + partitionLengths.toSeq)
       currentPartition = newPartition
-      out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
+      writer =
+        blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics)
     }
 
     while (sortedRecords.hasNext) {
@@ -214,18 +218,24 @@ private[spark] class UnsafeShuffleWriter[K, V](
       }
       val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer)
       val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer)
-      val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)
+      val recordLength: Int = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8).toInt
       println("Base offset is " + baseOffset)
       println("Record length is " + recordLength)
       // TODO: need to have a way to figure out whether a serializer supports relocation of
       // serialized objects or not. Sandy also ran into this in his patch (see
       // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might
       // as well just bypass this optimized code path in favor of the old one.
-      var i: Int = 0
-      while (i < recordLength) {
-        out.write(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + 16 + i))
-        i += 1
-      }
+      // TODO: re-use a buffer or avoid double-buffering entirely
+      val arr: Array[Byte] = new Array[Byte](recordLength)
+      PlatformDependent.copyMemory(
+        baseObject,
+        baseOffset + 16,
+        arr,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        recordLength)
+      writer.write(arr)
+      // TODO: add a test that detects whether we leave this call out:
+      writer.recordWritten()
     }
     closePartition()
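
Illustration (not from the patch itself): the comment on the new fileBufferSize field refers to SparkConf.getSizeAsKb, which parses the configured value in kibibytes, so a unit-less "32" and "32k" both resolve to 32 KiB, and multiplying by 1024 gives the buffer size in bytes. A minimal sketch of that conversion, assuming only a plain SparkConf with defaults disabled:

    import org.apache.spark.SparkConf

    // Illustration only: how fileBufferSize is derived from the config value.
    val conf = new SparkConf(loadDefaults = false).set("spark.shuffle.file.buffer", "1m")
    // getSizeAsKb returns KiB; "1m" -> 1024 KiB, while a bare "32" or "32k" -> 32 KiB.
    val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
    // Here fileBufferSize == 1048576 bytes; with the "32k" default it would be 32768.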
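
Illustration (not from the patch itself): the rewritten copy loop assumes a fixed in-page record layout: an 8-byte word at the record's base offset (its contents aren't read in this loop), the record length stored as a long at base + 8, and the serialized payload starting at base + 16, which is bulk-copied into a temporary array for the disk writer. A self-contained sketch of that layout, using sun.misc.Unsafe and an on-heap byte array as a stand-in for Spark's PlatformDependent and MemoryBlock page (all names here are illustrative):

    import sun.misc.Unsafe

    object RecordCopySketch {
      // Obtain the Unsafe instance reflectively (stand-in for PlatformDependent.UNSAFE).
      private val unsafe: Unsafe = {
        val f = classOf[Unsafe].getDeclaredField("theUnsafe")
        f.setAccessible(true)
        f.get(null).asInstanceOf[Unsafe]
      }
      private val BYTE_ARRAY_OFFSET: Long = Unsafe.ARRAY_BYTE_BASE_OFFSET.toLong

      def main(args: Array[String]): Unit = {
        val page = new Array[Byte](64)           // stand-in for a MemoryBlock page
        val base: AnyRef = page
        val baseOffset: Long = BYTE_ARRAY_OFFSET // a real record pointer also encodes an in-page offset

        // Write one record: [8-byte header word][8-byte length][serialized payload].
        val payload = "serialized record".getBytes("UTF-8")
        unsafe.putLong(base, baseOffset, 0L)                        // header word, not read back below
        unsafe.putLong(base, baseOffset + 8, payload.length.toLong) // record length as a long
        unsafe.copyMemory(payload, BYTE_ARRAY_OFFSET, base, baseOffset + 16, payload.length.toLong)

        // Read it back the way writeSortedRecordsToFile now does: length first, then one
        // bulk copy into a byte array that can be handed to the BlockObjectWriter.
        val recordLength = unsafe.getLong(base, baseOffset + 8).toInt
        val arr = new Array[Byte](recordLength)
        unsafe.copyMemory(base, baseOffset + 16, arr, BYTE_ARRAY_OFFSET, recordLength.toLong)
        assert(new String(arr, "UTF-8") == "serialized record")
      }
    }

Compared with the removed byte-at-a-time loop, the single copyMemory call moves the whole record at once instead of issuing one Unsafe.getByte call per byte, at the cost of the temporary array flagged in the new TODO.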