Commit e2d96ca

Expand serializer API and use new function to help control when new UnsafeShuffle path is used.
JoshRosen committed May 1, 2015
1 parent e267cee commit e2d96ca
Showing 4 changed files with 55 additions and 23 deletions.
@@ -125,6 +125,11 @@ class KryoSerializer(conf: SparkConf)
override def newInstance(): SerializerInstance = {
new KryoSerializerInstance(this)
}

override def supportsRelocationOfSerializedObjects: Boolean = {
// TODO: we should have a citation / explanatory comment here clarifying _why_ this is the case
newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset()
}
}

private[spark]
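For context on the TODO in the new `supportsRelocationOfSerializedObjects` override: Kryo's auto-reset is what makes byte-level relocation safe, because with auto-reset disabled Kryo may emit back-references to objects or class names written earlier in the same stream, so an individual record's bytes stop being self-contained. The following standalone sketch illustrates that effect; it is not part of this commit, assumes the Kryo 2.x API that Spark ships with (`setAutoReset`, `setReferences`, `writeClassAndObject`), and uses object names of our own choosing.

```scala
import java.io.ByteArrayOutputStream

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.Output

object KryoAutoResetSketch {
  /** Serialize the same object twice into one stream and return the two byte blocks. */
  private def twoWrites(autoReset: Boolean): (Array[Byte], Array[Byte]) = {
    val kryo = new Kryo()
    kryo.setReferences(true)
    kryo.setAutoReset(autoReset)
    val obj = new java.util.ArrayList[Integer](java.util.Arrays.asList(1, 2, 3))

    val bytes = new ByteArrayOutputStream()
    val output = new Output(bytes)
    kryo.writeClassAndObject(output, obj)
    output.flush()
    val first = bytes.toByteArray
    bytes.reset()
    kryo.writeClassAndObject(output, obj)
    output.flush()
    (first, bytes.toByteArray)
  }

  def main(args: Array[String]): Unit = {
    val (a1, a2) = twoWrites(autoReset = true)
    val (b1, b2) = twoWrites(autoReset = false)
    // With auto-reset on, both writes are identical, so the blocks can be reordered freely.
    println(s"autoReset=true : second write identical to first = ${a1.sameElements(a2)}")
    // With auto-reset off, the second write can be a back-reference into earlier output,
    // which is exactly what would break byte-level relocation of shuffle records.
    println(s"autoReset=false: second write identical to first = ${b1.sameElements(b2)}")
  }
}
```

Running this would typically print `true` for the first line and `false` for the second, which is the behaviour that `KryoSerializerInstance.getAutoReset()` is probing for above.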
26 changes: 25 additions & 1 deletion core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
import scala.reflect.ClassTag

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.annotation.{Experimental, DeveloperApi}
import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator}

/**
@@ -63,6 +63,30 @@ abstract class Serializer {

/** Creates a new [[SerializerInstance]]. */
def newInstance(): SerializerInstance

/**
* Returns true if this serializer supports relocation of its serialized objects and false
* otherwise. This should return true if and only if reordering the bytes of serialized objects
* in serialization stream output results in re-ordered input that can be read with the
* deserializer. For instance, the following should work if the serializer supports relocation:
*
* serOut.open()
* position = 0
* serOut.write(obj1)
* serOut.flush()
position = # of bytes written to stream so far
* obj1Bytes = [bytes 0 through position of stream]
* serOut.write(obj2)
serOut.flush()
* position2 = # of bytes written to stream so far
* obj2Bytes = bytes[position through position2 of stream]
*
serIn.open([obj2Bytes] concatenate [obj1Bytes]) should return (obj2, obj1)
*
* See SPARK-7311 for more discussion.
*/
@Experimental
def supportsRelocationOfSerializedObjects: Boolean = false
}


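The contract documented above can be exercised directly against a concrete serializer through the existing `SerializerInstance` stream API. The check below is a rough sketch that mirrors the pseudo-code in the scaladoc; it is not part of this patch, and the `RelocationCheck`/`holdsFor` names are ours.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.serializer.{KryoSerializer, Serializer}

object RelocationCheck {
  /** Returns true if swapping two serialized blocks still deserializes in swapped order. */
  def holdsFor(ser: Serializer, obj1: String, obj2: String): Boolean = {
    val instance = ser.newInstance()
    val bytesOut = new ByteArrayOutputStream()
    val serOut = instance.serializeStream(bytesOut)

    serOut.writeObject(obj1)
    serOut.flush()
    val position = bytesOut.size()                 // # of bytes written to the stream so far
    serOut.writeObject(obj2)
    serOut.flush()
    serOut.close()

    val all = bytesOut.toByteArray
    val obj1Bytes = all.slice(0, position)
    val obj2Bytes = all.slice(position, all.length)

    // Feed the blocks back in swapped order, as in the scaladoc example.
    val serIn = instance.deserializeStream(new ByteArrayInputStream(obj2Bytes ++ obj1Bytes))
    val first = serIn.readObject[String]()
    val second = serIn.readObject[String]()
    serIn.close()
    first == obj2 && second == obj1
  }

  def main(args: Array[String]): Unit = {
    // Expected to hold for Kryo with its default auto-reset behaviour.
    println(RelocationCheck.holdsFor(new KryoSerializer(new SparkConf(false)), "foo", "bar"))
  }
}
```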
@@ -22,7 +22,7 @@ import java.util

import com.esotericsoftware.kryo.io.ByteBufferOutputStream

import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
@@ -34,17 +34,31 @@ import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
import org.apache.spark.unsafe.sort.UnsafeSorter
import org.apache.spark.unsafe.sort.UnsafeSorter.{KeyPointerAndPrefix, PrefixComparator, PrefixComputer, RecordComparator}

private[spark] class UnsafeShuffleHandle[K, V](
private class UnsafeShuffleHandle[K, V](
shuffleId: Int,
override val numMaps: Int,
override val dependency: ShuffleDependency[K, V, V])
extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
require(UnsafeShuffleManager.canUseUnsafeShuffle(dependency))
}

private[spark] object UnsafeShuffleManager {
private[spark] object UnsafeShuffleManager extends Logging {
def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = {
dependency.aggregator.isEmpty && dependency.keyOrdering.isEmpty
val shufId = dependency.shuffleId
val serializer = Serializer.getSerializer(dependency.serializer)
if (!serializer.supportsRelocationOfSerializedObjects) {
log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " +
s"${serializer.getClass.getName}, does not support object relocation")
false
} else if (dependency.aggregator.isDefined) {
log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
false
} else if (dependency.keyOrdering.isDefined) {
log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined")
false
} else {
log.debug(s"Can use UnsafeShuffle for shuffle $shufId")
true
}
}
}
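Constructing a real `ShuffleDependency` needs an RDD and a partitioner, so for readability here is the same three-way check restated over plain values. The `ShuffleProfile` type and `eligibleForUnsafeShuffle` helper are illustrative names of ours, not part of the patch:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.{KryoSerializer, Serializer}

// Minimal stand-in for the fields that canUseUnsafeShuffle inspects on a ShuffleDependency.
case class ShuffleProfile(
    serializer: Serializer,
    hasAggregator: Boolean,
    hasKeyOrdering: Boolean)

object ShuffleProfile {
  /** Mirrors the three conditions above, minus the logging. */
  def eligibleForUnsafeShuffle(p: ShuffleProfile): Boolean = {
    p.serializer.supportsRelocationOfSerializedObjects &&
      !p.hasAggregator &&
      !p.hasKeyOrdering
  }

  def main(args: Array[String]): Unit = {
    val kryo = new KryoSerializer(new SparkConf(false))
    // A plain repartition-style shuffle with Kryo qualifies.
    println(eligibleForUnsafeShuffle(ShuffleProfile(kryo, hasAggregator = false, hasKeyOrdering = false)))
    // Map-side aggregation (or a key ordering) falls back to the existing sort-based path.
    println(eligibleForUnsafeShuffle(ShuffleProfile(kryo, hasAggregator = true, hasKeyOrdering = false)))
  }
}
```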

@@ -73,15 +87,13 @@ private object PartitionerPrefixComparator extends PrefixComparator {
}
}

private[spark] class UnsafeShuffleWriter[K, V](
private class UnsafeShuffleWriter[K, V](
shuffleBlockManager: IndexShuffleBlockManager,
handle: UnsafeShuffleHandle[K, V],
mapId: Int,
context: TaskContext)
extends ShuffleWriter[K, V] {

println("Construcing a new UnsafeShuffleWriter")

private[this] val memoryManager: TaskMemoryManager = context.taskMemoryManager()

private[this] val dep = handle.dependency
@@ -158,7 +170,6 @@ private[spark] class UnsafeShuffleWriter[K, V](
memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition)
PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId)
currentPagePosition += 8
println("The stored record length is " + serializedRecordSize)
PlatformDependent.UNSAFE.putLong(
currentPage.getBaseObject, currentPagePosition, serializedRecordSize)
currentPagePosition += 8
@@ -169,7 +180,6 @@
currentPagePosition,
serializedRecordSize)
currentPagePosition += serializedRecordSize
println("After writing record, current page position is " + currentPagePosition)
sorter.insertRecord(newRecordAddress)

// Reset for writing the next record
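Pieced together from the `putLong` and `copyMemory` calls above, each record occupies a slot of the form [8-byte partition id][8-byte record length][serialized bytes] inside the current data page, and only the encoded page-and-offset pointer is handed to the sorter. The toy re-creation below uses a plain `ByteBuffer` instead of the unsafe memory API and is our illustration, not Spark code:

```scala
import java.nio.ByteBuffer

object RecordLayoutSketch {
  /** Writes one record in the layout sketched above and returns its offset in the page. */
  def writeRecord(page: ByteBuffer, partitionId: Int, record: Array[Byte]): Int = {
    val recordOffset = page.position()   // analogous to currentPagePosition
    page.putLong(partitionId.toLong)     // 8-byte partition id (also used as the sort prefix)
    page.putLong(record.length.toLong)   // 8-byte record length
    page.put(record)                     // the serialized record bytes themselves
    recordOffset                         // the sorter stores only a pointer to this offset
  }

  def main(args: Array[String]): Unit = {
    val page = ByteBuffer.allocate(1 << 20)
    val offset = writeRecord(page, partitionId = 3, record = "payload".getBytes("UTF-8"))
    println(s"record written at offset $offset, next free offset ${page.position()}")
  }
}
```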
@@ -195,8 +205,10 @@
// TODO: don't close and re-open file handles so often; this could be inefficient

def closePartition(): Unit = {
writer.commitAndClose()
partitionLengths(currentPartition) = writer.fileSegment().length
if (writer != null) {
writer.commitAndClose()
partitionLengths(currentPartition) = writer.fileSegment().length
}
}

def switchToPartition(newPartition: Int): Unit = {
@@ -219,8 +231,6 @@
val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer)
val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer)
val recordLength: Int = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8).toInt
println("Base offset is " + baseOffset)
println("Record length is " + recordLength)
// TODO: need to have a way to figure out whether a serializer supports relocation of
// serialized objects or not. Sandy also ran into this in his patch (see
// https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might
@@ -244,12 +254,8 @@

/** Write a sequence of records to this task's output */
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
println("Opened writer!")

val sortedIterator = sortRecords(records)
val partitionLengths = writeSortedRecordsToFile(sortedIterator)

println("Partition lengths are " + partitionLengths.toSeq)
shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}
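The `writeIndexFile` call above is what lets reducers find their slice of the single map-output file: conceptually the index is just the running sum of `partitionLengths`. A small sketch of that derivation (the helper names are ours; the on-disk format itself belongs to `IndexShuffleBlockManager`):

```scala
object IndexOffsetsSketch {
  /** Start offset of each partition (plus the end offset), given per-partition lengths. */
  def offsetsFor(partitionLengths: Array[Long]): Array[Long] =
    partitionLengths.scanLeft(0L)(_ + _)

  def main(args: Array[String]): Unit = {
    // Partition 1 is empty, so its start offset equals partition 2's start offset.
    println(offsetsFor(Array(10L, 0L, 25L)).mkString(", "))   // 0, 10, 10, 35
  }
}
```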
@@ -264,7 +270,6 @@

/** Close this writer, passing along whether the map completed */
override def stop(success: Boolean): Option[MapStatus] = {
println("Stopping unsafeshufflewriter")
try {
if (stopping) {
None
@@ -300,7 +305,6 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager
numMaps: Int,
dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) {
println("Opening unsafeShuffleWriter")
new UnsafeShuffleHandle[K, V](
shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else {
@@ -131,8 +131,7 @@ private[spark] class ExternalSorter[K, V, C](
private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
private val useSerializedPairBuffer =
!ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
ser.isInstanceOf[KryoSerializer] &&
serInstance.asInstanceOf[KryoSerializerInstance].getAutoReset
ser.supportsRelocationOfSerializedObjects

// Data structures to store in-memory objects before we spill. Depending on whether we have an
// Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
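The net effect of this last hunk is that `ExternalSorter` now asks the serializer itself whether the serialized map-output buffer is safe to use, instead of hard-coding a Kryo-plus-auto-reset check. A quick way to see which built-in serializers opt in as of this commit (only `KryoSerializer` overrides the new method here, so `JavaSerializer` still reports the conservative default):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}

object WhichPathDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
    // Kryo with default auto-reset advertises relocation support, so the
    // serialized pair buffer stays enabled for it.
    println(new KryoSerializer(conf).supportsRelocationOfSerializedObjects)   // expected: true
    // JavaSerializer is untouched by this commit, so it falls back to the
    // default of false and ExternalSorter keeps the object-based buffer.
    println(new JavaSerializer(conf).supportsRelocationOfSerializedObjects)   // expected: false
  }
}
```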
