More cleanup

JoshRosen committed May 1, 2015
1 parent 8e3ec20 commit 253f13e
Showing 1 changed file with 23 additions and 21 deletions.
@@ -115,8 +115,8 @@ private[spark] class UnsafeShuffleWriter[K, V](
     val serializer = Serializer.getSerializer(dep.serializer).newInstance()
     val PAGE_SIZE = 1024 * 1024 * 1
 
-    var currentPage: MemoryBlock = memoryManager.allocatePage(PAGE_SIZE)
-    var currentPagePosition: Long = currentPage.getBaseOffset
+    var currentPage: MemoryBlock = null
+    var currentPagePosition: Long = PAGE_SIZE
 
     def ensureSpaceInDataPage(spaceRequired: Long): Unit = {
       if (spaceRequired > PAGE_SIZE) {
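
A note on this hunk: it makes page allocation lazy. currentPage starts as null and currentPagePosition starts at PAGE_SIZE, so the first call to ensureSpaceInDataPage sees a "full" page and allocates one on demand. Below is a minimal standalone sketch of that bump-pointer pattern; the names (Page, ensureSpace) are hypothetical stand-ins, not the real Spark memory-manager API.

    object LazyPageDemo {
      final val PageSize: Long = 1024 * 1024

      // Stand-in for MemoryBlock; real pages carry a base object and offset.
      final class Page(val size: Long)

      var currentPage: Page = null
      var position: Long = PageSize // starts "full" so the first use forces allocation

      def ensureSpace(required: Long): Unit = {
        require(required <= PageSize, s"$required exceeds page size $PageSize")
        if (position + required > PageSize) { // also true while no page exists yet
          currentPage = new Page(PageSize)    // stands in for memoryManager.allocatePage
          position = 0L
        }
      }

      def main(args: Array[String]): Unit = {
        ensureSpace(16) // triggers the lazy first allocation
        println(s"position after first ensureSpace: $position") // 0
      }
    }
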
@@ -143,6 +143,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
       serBufferSerStream.flush()
 
       val serializedRecordSize = byteBuffer.position()
+      assert(serializedRecordSize > 0)
       // TODO: we should run the partition extraction function _now_, at insert time, rather than
       // requiring it to be stored alongside the data, since this may lead to double storage
       val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8
@@ -152,17 +153,17 @@ private[spark] class UnsafeShuffleWriter[K, V](
         memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition)
       PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId)
       currentPagePosition += 8
-      println("The stored record length is " + byteBuffer.position())
+      println("The stored record length is " + serializedRecordSize)
       PlatformDependent.UNSAFE.putLong(
-        currentPage.getBaseObject, currentPagePosition, byteBuffer.position())
+        currentPage.getBaseObject, currentPagePosition, serializedRecordSize)
       currentPagePosition += 8
       PlatformDependent.copyMemory(
         serArray,
         PlatformDependent.BYTE_ARRAY_OFFSET,
         currentPage.getBaseObject,
         currentPagePosition,
-        byteBuffer.position())
-      currentPagePosition += byteBuffer.position()
+        serializedRecordSize)
+      currentPagePosition += serializedRecordSize
       println("After writing record, current page position is " + currentPagePosition)
       sorter.insertRecord(newRecordAddress)
 
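The write path above lays each record out as an 8-byte partition id, then an 8-byte record length, then the serialized bytes (hence the "+ 8 + 8" in sizeRequirementInSortDataPage). A hedged illustration of the same layout, using a plain java.nio.ByteBuffer in place of PlatformDependent.UNSAFE:

    import java.nio.ByteBuffer

    object RecordLayoutDemo {
      // Writes [partitionId: 8 bytes][length: 8 bytes][payload] into the page.
      def writeRecord(page: ByteBuffer, partitionId: Long, record: Array[Byte]): Unit = {
        page.putLong(partitionId)
        page.putLong(record.length.toLong)
        page.put(record)
      }

      def main(args: Array[String]): Unit = {
        val page = ByteBuffer.allocate(64)
        writeRecord(page, 3L, "hello".getBytes("UTF-8"))
        println(s"bytes used: ${page.position()}") // 8 + 8 + 5 = 21
      }
    }
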
@@ -195,10 +196,12 @@ private[spark] class UnsafeShuffleWriter[K, V](
     }
 
     def switchToPartition(newPartition: Int): Unit = {
+      assert(newPartition > currentPartition, s"new partition $newPartition should be > $currentPartition")
       if (currentPartition != -1) {
         closePartition()
         prevPartitionLength = partitionLengths(currentPartition)
       }
+      println(s"Before switching to partition $newPartition, partition lengths are " + partitionLengths.toSeq)
       currentPartition = newPartition
       out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
     }
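
The new assert documents an invariant of this code path: records arrive sorted by partition id, so switchToPartition is only ever called with strictly increasing partition numbers, and each switch closes out the previous partition before reopening the append-mode output stream. A small sketch of that bookkeeping, with hypothetical stand-ins for the real stream handling:

    object PartitionSwitchDemo {
      val partitionLengths = new Array[Long](4) // one length per reduce partition
      var currentPartition = -1

      def closePartition(): Unit =
        println(s"closing partition $currentPartition, length ${partitionLengths(currentPartition)}")

      def switchToPartition(newPartition: Int): Unit = {
        assert(newPartition > currentPartition,
          s"new partition $newPartition should be > $currentPartition")
        if (currentPartition != -1) closePartition()
        currentPartition = newPartition
      }

      def main(args: Array[String]): Unit = {
        switchToPartition(0)
        partitionLengths(0) = 42L
        switchToPartition(2) // skipping empty partitions is fine; going backwards would fail
      }
    }
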
@@ -214,11 +217,11 @@ private[spark] class UnsafeShuffleWriter[K, V](
       val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)
       println("Base offset is " + baseOffset)
       println("Record length is " + recordLength)
-      var i: Int = 0
       // TODO: need to have a way to figure out whether a serializer supports relocation of
       // serialized objects or not. Sandy also ran into this in his patch (see
       // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might
       // as well just bypass this optimized code path in favor of the old one.
+      var i: Int = 0
       while (i < recordLength) {
         out.write(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + 16 + i))
         i += 1
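
This read loop mirrors the write layout: the partition id sits at baseOffset, the record length at baseOffset + 8, and the payload from baseOffset + 16 onward, streamed out one byte at a time. A companion sketch, again using java.nio.ByteBuffer as a stand-in for PlatformDependent.UNSAFE:

    import java.io.ByteArrayOutputStream
    import java.nio.ByteBuffer

    object RecordReadDemo {
      def main(args: Array[String]): Unit = {
        // Build a fake page with the same layout the write path produces.
        val page = ByteBuffer.allocate(64)
        page.putLong(3L).putLong(5L).put("hello".getBytes("UTF-8"))

        val recordLength = page.getLong(8) // record length lives at baseOffset + 8
        val out = new ByteArrayOutputStream()
        var i = 0
        while (i < recordLength) {
          out.write(page.get(16 + i)) // payload starts at baseOffset + 16
          i += 1
        }
        println(new String(out.toByteArray, "UTF-8")) // prints: hello
      }
    }
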
@@ -241,6 +244,14 @@ private[spark] class UnsafeShuffleWriter[K, V](
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
 
+  private def freeMemory(): Unit = {
+    val iter = allocatedPages.iterator()
+    while (iter.hasNext) {
+      memoryManager.freePage(iter.next())
+      iter.remove()
+    }
+  }
+
   /** Close this writer, passing along whether the map completed */
   override def stop(success: Boolean): Option[MapStatus] = {
     println("Stopping unsafeshufflewriter")
@@ -249,6 +260,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
         None
       } else {
         stopping = true
+        freeMemory()
         if (success) {
           Option(mapStatus)
         } else {
@@ -258,24 +270,14 @@ private[spark] class UnsafeShuffleWriter[K, V](
         }
       }
     } finally {
-      // Clean up our sorter, which may have its own intermediate files
-      if (!allocatedPages.isEmpty) {
-        val iter = allocatedPages.iterator()
-        while (iter.hasNext) {
-          memoryManager.freePage(iter.next())
-          iter.remove()
-        }
-        val startTime = System.nanoTime()
-        //sorter.stop()
-        context.taskMetrics().shuffleWriteMetrics.foreach(
-          _.incShuffleWriteTime(System.nanoTime - startTime))
-      }
+      freeMemory()
+      val startTime = System.nanoTime()
+      context.taskMetrics().shuffleWriteMetrics.foreach(
+        _.incShuffleWriteTime(System.nanoTime - startTime))
     }
   }
 }
 
-
-
 private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager {
 
   private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
