[SPARK-13926] Automatically use Kryo serializer when shuffling RDDs with simple types

Because ClassTags are available when constructing ShuffledRDD, we can use them to automatically select Kryo for shuffle serialization when the RDD's types are known to be compatible with it.

This patch introduces `SerializerManager`, a component which picks the "best" serializer for a shuffle given the elements' ClassTags. It will automatically pick a Kryo serializer for ShuffledRDDs whose key, value, and/or combiner types are primitives, arrays of primitives, or strings. In the future we can use this class as a narrow extension point to integrate specialized serializers for other types, such as ByteBuffers.
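For illustration, here is a minimal sketch of the selection behavior (not part of this diff); it assumes code compiled inside the org.apache.spark packages, since SerializerManager is private[spark]:

import scala.reflect.ClassTag
import org.apache.spark.SparkConf
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager}

val conf = new SparkConf()
// JavaSerializer stands in here for whatever `spark.serializer` resolves to.
val manager = new SerializerManager(new JavaSerializer(conf), conf)

// Int keys and Long values are both Kryo-compatible, so Kryo is picked.
assert(manager.getSerializer(ClassTag.Int, ClassTag.Long).isInstanceOf[KryoSerializer])

// An arbitrary user-defined class is not, so the default serializer is kept.
class User(val name: String)
assert(manager.getSerializer(ClassTag.Int, ClassTag(classOf[User])).isInstanceOf[JavaSerializer])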

In a planned followup patch, I will extend the BlockManager APIs so that we're able to use similar automatic serializer selection when caching RDDs (this is a little trickier because the ClassTags need to be threaded through many more places).

Author: Josh Rosen <joshrosen@databricks.com>

Closes #11755 from JoshRosen/automatically-pick-best-serializer.
JoshRosen authored and rxin committed Mar 17, 2016
1 parent d1c193a commit de1a84e
Showing 21 changed files with 135 additions and 73 deletions.
@@ -115,7 +115,7 @@ public BypassMergeSortShuffleWriter(
this.partitioner = dep.partitioner();
this.numPartitions = partitioner.numPartitions();
this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics();
-this.serializer = Serializer.getSerializer(dep.serializer());
+this.serializer = dep.serializer();
this.shuffleBlockResolver = shuffleBlockResolver;
}

@@ -116,7 +116,7 @@ public UnsafeShuffleWriter(
this.mapId = mapId;
final ShuffleDependency<K, V, V> dep = handle.dependency();
this.shuffleId = dep.shuffleId();
-this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
+this.serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics();
this.taskContext = taskContext;
8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/Dependency.scala
@@ -59,9 +59,9 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
*
* @param _rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
-* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
-* the default serializer, as specified by `spark.serializer` config option, will
-* be used.
+* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If not set
+* explicitly then the default serializer, as specified by `spark.serializer`
+* config option, will be used.
* @param keyOrdering key ordering for RDD's shuffles
* @param aggregator map/reduce-side aggregator for RDD's shuffle
* @param mapSideCombine whether to perform partial aggregation (also known as map-side combine)
@@ -70,7 +70,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
@transient private val _rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
-val serializer: Option[Serializer] = None,
+val serializer: Serializer = SparkEnv.get.serializer,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false)
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -35,7 +35,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
-import org.apache.spark.serializer.{JavaSerializer, Serializer}
+import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
import org.apache.spark.util.{RpcUtils, Utils}
@@ -56,6 +56,7 @@ class SparkEnv (
private[spark] val rpcEnv: RpcEnv,
val serializer: Serializer,
val closureSerializer: Serializer,
+val serializerManager: SerializerManager,
val mapOutputTracker: MapOutputTracker,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
@@ -276,6 +277,8 @@ object SparkEnv extends Logging {
"spark.serializer", "org.apache.spark.serializer.JavaSerializer")
logDebug(s"Using serializer: ${serializer.getClass}")

+val serializerManager = new SerializerManager(serializer, conf)

val closureSerializer = new JavaSerializer(conf)

def registerOrLookupEndpoint(
@@ -368,6 +371,7 @@
rpcEnv,
serializer,
closureSerializer,
+serializerManager,
mapOutputTracker,
shuffleManager,
broadcastManager,
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -86,11 +86,11 @@ class CoGroupedRDD[K: ClassTag](
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Array[CoGroup]

-private var serializer: Option[Serializer] = None
+private var serializer: Serializer = SparkEnv.get.serializer

/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
-this.serializer = Option(serializer)
+this.serializer = serializer
this
}

12 changes: 10 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -44,7 +44,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
part: Partitioner)
extends RDD[(K, C)](prev.context, Nil) {

-private var serializer: Option[Serializer] = None
+private var userSpecifiedSerializer: Option[Serializer] = None

private var keyOrdering: Option[Ordering[K]] = None

@@ -54,7 +54,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](

/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = {
-this.serializer = Option(serializer)
+this.userSpecifiedSerializer = Option(serializer)
this
}

@@ -77,6 +77,14 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
}

override def getDependencies: Seq[Dependency[_]] = {
+val serializer = userSpecifiedSerializer.getOrElse {
+  val serializerManager = SparkEnv.get.serializerManager
+  if (mapSideCombine) {
+    serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[C]])
+  } else {
+    serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[V]])
+  }
+}
List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
}

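As a usage-level illustration (not part of the diff), a shuffle whose key, value, and combiner types are all Kryo-compatible now selects Kryo automatically, with no setSerializer call. The sketch below assumes an already-constructed SparkContext named sc:

// String keys with Int values and Int combiners are all Kryo-compatible,
// so the ShuffleDependency built in getDependencies above picks KryoSerializer.
val counts = sc.parallelize(Seq("a", "b", "a"))
  .map(word => (word, 1))
  .reduceByKey(_ + _)
counts.collect()
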
10 changes: 1 addition & 9 deletions core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -30,7 +30,6 @@ import org.apache.spark.Partitioner
import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext
-import org.apache.spark.serializer.Serializer

/**
* An optimized version of cogroup for set difference/subtraction.
@@ -54,13 +53,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {

-private var serializer: Option[Serializer] = None
-
-/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
-def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
-  this.serializer = Option(serializer)
-  this
-}

override def getDependencies: Seq[Dependency[_]] = {
def rddDependency[T1: ClassTag, T2: ClassTag](rdd: RDD[_ <: Product2[T1, T2]])
@@ -70,7 +62,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
-new ShuffleDependency[T1, T2, Any](rdd, part, serializer)
+new ShuffleDependency[T1, T2, Any](rdd, part)
}
}
Seq(rddDependency[K, V](rdd1), rddDependency[K, W](rdd2))
12 changes: 0 additions & 12 deletions core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -100,18 +100,6 @@ abstract class Serializer {
}


-@DeveloperApi
-object Serializer {
-  def getSerializer(serializer: Serializer): Serializer = {
-    if (serializer == null) SparkEnv.get.serializer else serializer
-  }
-
-  def getSerializer(serializer: Option[Serializer]): Serializer = {
-    serializer.getOrElse(SparkEnv.get.serializer)
-  }
-}


/**
* :: DeveloperApi ::
* An instance of a serializer, for use by one thread at a time.
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.serializer

import scala.reflect.ClassTag

import org.apache.spark.SparkConf

/**
* Component that selects which [[Serializer]] to use for shuffles.
*/
private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {

  private[this] val kryoSerializer = new KryoSerializer(conf)

  private[this] val primitiveAndPrimitiveArrayClassTags: Set[ClassTag[_]] = {
    val primitiveClassTags = Set[ClassTag[_]](
      ClassTag.Boolean,
      ClassTag.Byte,
      ClassTag.Char,
      ClassTag.Double,
      ClassTag.Float,
      ClassTag.Int,
      ClassTag.Long,
      ClassTag.Null,
      ClassTag.Short
    )
    val arrayClassTags = primitiveClassTags.map(_.wrap)
    primitiveClassTags ++ arrayClassTags
  }

  private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]]

  private def canUseKryo(ct: ClassTag[_]): Boolean = {
    primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
  }

  def getSerializer(ct: ClassTag[_]): Serializer = {
    if (canUseKryo(ct)) {
      kryoSerializer
    } else {
      defaultSerializer
    }
  }

  /**
   * Pick the best serializer for shuffling an RDD of key-value pairs.
   */
  def getSerializer(keyClassTag: ClassTag[_], valueClassTag: ClassTag[_]): Serializer = {
    if (canUseKryo(keyClassTag) && canUseKryo(valueClassTag)) {
      kryoSerializer
    } else {
      defaultSerializer
    }
  }
}
@@ -54,8 +54,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
blockManager.wrapForCompression(blockId, inputStream)
}

-val ser = Serializer.getSerializer(dep.serializer)
-val serializerInstance = ser.newInstance()
+val serializerInstance = dep.serializer.newInstance()

// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { wrappedStream =>
@@ -100,7 +99,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter =
-new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
+new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
@@ -21,7 +21,6 @@ import java.io.IOException

import org.apache.spark._
import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._
import org.apache.spark.storage.DiskBlockObjectWriter

@@ -44,9 +43,8 @@ private[spark] class HashShuffleWriter[K, V](
private val writeMetrics = metrics.registerShuffleWriteMetrics()

private val blockManager = SparkEnv.get.blockManager
-private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
-private val shuffle = shuffleBlockResolver.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser,
-  writeMetrics)
+private val shuffle = shuffleBlockResolver.forMapTask(dep.shuffleId, mapId, numOutputSplits,
+  dep.serializer, writeMetrics)

/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
@@ -20,7 +20,6 @@ package org.apache.spark.shuffle.sort
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark._
-import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._

/**
@@ -186,10 +185,9 @@ private[spark] object SortShuffleManager extends Logging {
def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
val shufId = dependency.shuffleId
val numPartitions = dependency.partitioner.numPartitions
-val serializer = Serializer.getSerializer(dependency.serializer)
-if (!serializer.supportsRelocationOfSerializedObjects) {
+if (!dependency.serializer.supportsRelocationOfSerializedObjects) {
log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
s"${serializer.getClass.getName}, does not support object relocation")
s"${dependency.serializer.getClass.getName}, does not support object relocation")
false
} else if (dependency.aggregator.isDefined) {
log.debug(
@@ -91,7 +91,7 @@ private[spark] class ExternalSorter[K, V, C](
aggregator: Option[Aggregator[K, V, C]] = None,
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
-serializer: Option[Serializer] = None)
+serializer: Serializer = SparkEnv.get.serializer)
extends Logging
with Spillable[WritablePartitionedPairCollection[K, C]] {

@@ -107,8 +107,7 @@

private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
-private val ser = Serializer.getSerializer(serializer)
-private val serInstance = ser.newInstance()
+private val serInstance = serializer.newInstance()

// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
@@ -191,7 +191,7 @@ public Tuple2<TempShuffleBlockId, File> answer(
});

when(taskContext.taskMetrics()).thenReturn(taskMetrics);
-when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
+when(shuffleDep.serializer()).thenReturn(serializer);
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
}

@@ -127,7 +127,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
// Create a mocked shuffle handle to pass into HashShuffleReader.
val shuffleHandle = {
val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
-when(dependency.serializer).thenReturn(Some(serializer))
+when(dependency.serializer).thenReturn(serializer)
when(dependency.aggregator).thenReturn(None)
when(dependency.keyOrdering).thenReturn(None)
new BaseShuffleHandle(shuffleId, numMaps, dependency)
@@ -66,7 +66,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
dependency = dependency
)
when(dependency.partitioner).thenReturn(new HashPartitioner(7))
-when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf)))
+when(dependency.serializer).thenReturn(new JavaSerializer(conf))
when(taskContext.taskMetrics()).thenReturn(taskMetrics)
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
doAnswer(new Answer[Void] {
@@ -41,7 +41,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers {

private def shuffleDep(
partitioner: Partitioner,
-serializer: Option[Serializer],
+serializer: Serializer,
keyOrdering: Option[Ordering[Any]],
aggregator: Option[Aggregator[Any, Any, Any]],
mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = {
@@ -56,7 +56,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers {
}

test("supported shuffle dependencies for serialized shuffle") {
-val kryo = Some(new KryoSerializer(new SparkConf()))
+val kryo = new KryoSerializer(new SparkConf())

assert(canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
@@ -88,8 +88,8 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers {
}

test("unsupported shuffle dependencies for serialized shuffle") {
-val kryo = Some(new KryoSerializer(new SparkConf()))
-val java = Some(new JavaSerializer(new SparkConf()))
+val kryo = new KryoSerializer(new SparkConf())
+val java = new JavaSerializer(new SparkConf())

// We only support serializers that support object relocation
assert(!canUseSerializedShuffle(shuffleDep(
