From 624cb3b7340f2d6624a4bf35cc4cd9f59739fa9c Mon Sep 17 00:00:00 2001 From: mcheah Date: Mon, 5 Oct 2015 11:56:44 -0400 Subject: [PATCH 1/3] [SPARK-10926][CORE] Create WeakReferenceCleaner interface that ContextCleaner extends Preparing the way for SPARK-10250 which will introduce other objects that will be cleaned via detecting their being cleaned up. --- .../org/apache/spark/ContextCleaner.scala | 116 ++++-------------- .../apache/spark/WeakReferenceCleaner.scala | 91 ++++++++++++++ .../spark/util/cleanup/CleanupTasks.scala | 41 +++++++ project/MimaExcludes.scala | 29 +++++ 4 files changed, 184 insertions(+), 93 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/WeakReferenceCleaner.scala create mode 100644 core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index d23c1533db758..a14a55ec352d3 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -17,35 +17,13 @@ package org.apache.spark -import java.lang.ref.{ReferenceQueue, WeakReference} - -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} - import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} import org.apache.spark.util.Utils +import org.apache.spark.util.cleanup.{ CleanAccum, CleanBroadcast, CleanCheckpoint } +import org.apache.spark.util.cleanup.{ CleanRDD, CleanShuffle, CleanupTask } -/** - * Classes that represent cleaning tasks. - */ -private sealed trait CleanupTask -private case class CleanRDD(rddId: Int) extends CleanupTask -private case class CleanShuffle(shuffleId: Int) extends CleanupTask -private case class CleanBroadcast(broadcastId: Long) extends CleanupTask -private case class CleanAccum(accId: Long) extends CleanupTask -private case class CleanCheckpoint(rddId: Int) extends CleanupTask - -/** - * A WeakReference associated with a CleanupTask. - * - * When the referent object becomes only weakly reachable, the corresponding - * CleanupTaskWeakReference is automatically added to the given reference queue. - */ -private class CleanupTaskWeakReference( - val task: CleanupTask, - referent: AnyRef, - referenceQueue: ReferenceQueue[AnyRef]) - extends WeakReference(referent, referenceQueue) +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} /** * An asynchronous cleaner for RDD, shuffle, and broadcast state. @@ -54,18 +32,11 @@ private class CleanupTaskWeakReference( * to be processed when the associated object goes out of scope of the application. Actual * cleanup is performed in a separate daemon thread. */ -private[spark] class ContextCleaner(sc: SparkContext) extends Logging { - - private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] - with SynchronizedBuffer[CleanupTaskWeakReference] - - private val referenceQueue = new ReferenceQueue[AnyRef] +private[spark] class ContextCleaner(sc: SparkContext) extends WeakReferenceCleaner { private val listeners = new ArrayBuffer[CleanerListener] with SynchronizedBuffer[CleanerListener] - private val cleaningThread = new Thread() { override def run() { keepCleaning() }} - /** * Whether the cleaning thread will block on cleanup tasks (other than shuffle, which * is controlled by the `spark.cleaner.referenceTracking.blocking.shuffle` parameter). 
@@ -92,35 +63,11 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val blockOnShuffleCleanupTasks = sc.conf.getBoolean( "spark.cleaner.referenceTracking.blocking.shuffle", false) - @volatile private var stopped = false - /** Attach a listener object to get information of when objects are cleaned. */ def attachListener(listener: CleanerListener): Unit = { listeners += listener } - /** Start the cleaner. */ - def start(): Unit = { - cleaningThread.setDaemon(true) - cleaningThread.setName("Spark Context Cleaner") - cleaningThread.start() - } - - /** - * Stop the cleaning thread and wait until the thread has finished running its current task. - */ - def stop(): Unit = { - stopped = true - // Interrupt the cleaning thread, but wait until the current task has finished before - // doing so. This guards against the race condition where a cleaning thread may - // potentially clean similarly named variables created by a different SparkContext, - // resulting in otherwise inexplicable block-not-found exceptions (SPARK-6132). - synchronized { - cleaningThread.interrupt() - } - cleaningThread.join() - } - /** Register a RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) @@ -145,43 +92,30 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { registerForCleanup(rdd, CleanCheckpoint(parentId)) } - /** Register an object for cleanup. */ - private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { - referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) + /** Keep cleaning RDD, shuffle, and broadcast state. */ + override protected def keepCleaning(): Unit = Utils.tryOrStopSparkContext(sc) { + super.keepCleaning() } - /** Keep cleaning RDD, shuffle, and broadcast state. 
*/ - private def keepCleaning(): Unit = Utils.tryOrStopSparkContext(sc) { - while (!stopped) { - try { - val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT)) - .map(_.asInstanceOf[CleanupTaskWeakReference]) - // Synchronize here to avoid being interrupted on stop() - synchronized { - reference.map(_.task).foreach { task => - logDebug("Got cleaning task " + task) - referenceBuffer -= reference.get - task match { - case CleanRDD(rddId) => - doCleanupRDD(rddId, blocking = blockOnCleanupTasks) - case CleanShuffle(shuffleId) => - doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) - case CleanBroadcast(broadcastId) => - doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) - case CleanAccum(accId) => - doCleanupAccum(accId, blocking = blockOnCleanupTasks) - case CleanCheckpoint(rddId) => - doCleanCheckpoint(rddId) - } - } - } - } catch { - case ie: InterruptedException if stopped => // ignore - case e: Exception => logError("Error in cleaning thread", e) - } + protected def handleCleanupForSpecificTask(task: CleanupTask): Unit = { + task match { + case CleanRDD(rddId) => + doCleanupRDD(rddId, blocking = blockOnCleanupTasks) + case CleanShuffle(shuffleId) => + doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks) + case CleanBroadcast(broadcastId) => + doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + case CleanAccum(accId) => + doCleanupAccum(accId, blocking = blockOnCleanupTasks) + case CleanCheckpoint(rddId) => + doCleanCheckpoint(rddId) + case unknown => + logWarning(s"Got a cleanup task $unknown that cannot be handled by ContextCleaner,") } } + protected def cleanupThreadName(): String = "Context Cleaner" + /** Perform RDD cleanup. */ def doCleanupRDD(rddId: Int, blocking: Boolean): Unit = { try { @@ -252,10 +186,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] } -private object ContextCleaner { - private val REF_QUEUE_POLL_TIMEOUT = 100 -} - /** * Listener class used for testing when any item has been cleaned by the Cleaner class. */ diff --git a/core/src/main/scala/org/apache/spark/WeakReferenceCleaner.scala b/core/src/main/scala/org/apache/spark/WeakReferenceCleaner.scala new file mode 100644 index 0000000000000..0dd6d4773dcb6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/WeakReferenceCleaner.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark + +import java.lang.ref.ReferenceQueue + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} + +import org.apache.spark.util.cleanup.{CleanupTask, CleanupTaskWeakReference} + +/** + * Utility trait that keeps a long running thread for cleaning up weak references + * after they are GCed. Currently implemented by ContextCleaner and ExecutorCleaner + * only. + */ +private[spark] trait WeakReferenceCleaner extends Logging { + + private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] + with SynchronizedBuffer[CleanupTaskWeakReference] + + private val referenceQueue = new ReferenceQueue[AnyRef] + + private val cleaningThread = new Thread() { override def run() { keepCleaning() }} + + private var stopped = false + + /** Start the cleaner. */ + def start(): Unit = { + cleaningThread.setDaemon(true) + cleaningThread.setName(cleanupThreadName()) + cleaningThread.start() + } + + def stop(): Unit = { + stopped = true + synchronized { + // Interrupt the cleaning thread, but wait until the current task has finished before + // doing so. This guards against the race condition where a cleaning thread may + // potentially clean similarly named variables created by a different SparkContext, + // resulting in otherwise inexplicable block-not-found exceptions (SPARK-6132). + cleaningThread.interrupt() + } + cleaningThread.join() + } + + protected def keepCleaning(): Unit = { + while (!stopped) { + try { + val reference = Option(referenceQueue.remove(WeakReferenceCleaner.REF_QUEUE_POLL_TIMEOUT)) + .map(_.asInstanceOf[CleanupTaskWeakReference]) + // Synchronize here to avoid being interrupted on stop() + synchronized { + reference.map(_.task).foreach { task => + logDebug("Got cleaning task " + task) + referenceBuffer -= reference.get + handleCleanupForSpecificTask(task) + } + } + } catch { + case ie: InterruptedException if stopped => // ignore + case e: Exception => logError("Error in cleaning thread", e) + } + } + } + + /** Register an object for cleanup. */ + protected def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = { + referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) + } + + protected def handleCleanupForSpecificTask(task: CleanupTask) + protected def cleanupThreadName(): String +} + +private object WeakReferenceCleaner { + private val REF_QUEUE_POLL_TIMEOUT = 100 +} diff --git a/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala b/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala new file mode 100644 index 0000000000000..76dbdd8d3c44f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.util.cleanup + +import java.lang.ref.{ReferenceQueue, WeakReference} + +/** + * Classes that represent cleaning tasks. + */ +private[spark] sealed trait CleanupTask +private[spark] case class CleanRDD(rddId: Int) extends CleanupTask +private[spark] case class CleanShuffle(shuffleId: Int) extends CleanupTask +private[spark] case class CleanBroadcast(broadcastId: Long) extends CleanupTask +private[spark] case class CleanAccum(accId: Long) extends CleanupTask +private[spark] case class CleanCheckpoint(rddId: Int) extends CleanupTask + +/** + * A WeakReference associated with a CleanupTask. + * + * When the referent object becomes only weakly reachable, the corresponding + * CleanupTaskWeakReference is automatically added to the given reference queue. + */ +private[spark] class CleanupTaskWeakReference( + val task: CleanupTask, + referent: AnyRef, + referenceQueue: ReferenceQueue[AnyRef]) + extends WeakReference(referent, referenceQueue) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 2d4d146f51339..14e1b05fcd9f5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -80,6 +80,35 @@ object MimaExcludes { "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.regression.LeastSquaresCostFun.this") + ) ++ Seq( + // Cleanup task types are marked as private but Mima also confused by this change, + // similar to SPARK-10381. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanAccum"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanAccum$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanBroadcast"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanBroadcast$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanCheckpoint"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanCheckpoint$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanupTask"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanRDD"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanRDD$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanShuffle"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanShuffle$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanupTaskWeakReference"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.CleanupTaskWeakReference$") ) case v if v.startsWith("1.5") => Seq( From 9b143e9147c1cd68b8d36e5292d2eff30c56a79f Mon Sep 17 00:00:00 2001 From: mcheah Date: Mon, 5 Oct 2015 11:58:49 -0400 Subject: [PATCH 2/3] [SPARK-10250][CORE] External group by to handle huge keys. 
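Adds an ExternalList collection that can spill its contents to disk, and switches groupByKey to build per-key groups with ExternalList when spark.groupBy.spill.enabled is set (it defaults to false, so the existing CompactBuffer path is unchanged). The spill/read-back machinery is shared with ExternalAppendOnlyMap through a new SpillableCollection trait, and a new ExecutorCleaner (built on the WeakReferenceCleaner trait from the previous commit) deletes spill files once their lists are garbage collected or the owning task completes.

As a rough sketch of how the new path is exercised (the application name and numbers below are made up for illustration and are not part of this patch):

    import org.apache.spark.{SparkConf, SparkContext}

    // Opt in to the spilling group-by before creating the SparkContext.
    val conf = new SparkConf()
      .setAppName("groupBy-spill-example")
      .set("spark.groupBy.spill.enabled", "true")
    val sc = new SparkContext(conf)

    // With the flag on, each group is backed by an ExternalList that can
    // spill to disk instead of a purely in-memory CompactBuffer.
    val grouped = sc.parallelize(1 to 10000000)
      .map(x => (x % 8, x))
      .groupByKey()
    println(grouped.mapValues(_.size).collect().mkString(", "))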
--- .../org/apache/spark/ExecutorCleaner.scala | 55 ++++ .../scala/org/apache/spark/SparkEnv.scala | 5 + .../apache/spark/rdd/PairRDDFunctions.scala | 41 ++- .../spark/serializer/KryoSerializer.scala | 3 +- .../spark/util/cleanup/CleanupTasks.scala | 1 + .../collection/ExternalAppendOnlyMap.scala | 238 ++--------------- .../spark/util/collection/ExternalList.scala | 212 +++++++++++++++ .../SizeTrackingCompactBuffer.scala | 47 ++++ .../spark/util/collection/Spillable.scala | 20 +- .../util/collection/SpillableCollection.scala | 252 ++++++++++++++++++ .../util/collection/ExternalListSuite.scala | 150 +++++++++++ 11 files changed, 798 insertions(+), 226 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/ExecutorCleaner.scala create mode 100644 core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala create mode 100644 core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala create mode 100644 core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala create mode 100644 core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala b/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala new file mode 100644 index 0000000000000..716f0906e9fc3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ExecutorCleaner.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark + +import java.io.File + +import org.apache.spark.util.cleanup.{CleanupTask, CleanExternalList} +import org.apache.spark.util.collection.ExternalList + +/** + * Asynchronous cleaner for objects created on the Executor. So far + * only supports cleaning up ExternalList objects. Equivalent to ContextCleaner + * but for objects on the Executor heap. 
+ */ +private[spark] class ExecutorCleaner extends WeakReferenceCleaner { + + def registerExternalListForCleanup(list: ExternalList[_]): Unit = { + registerForCleanup(list, CleanExternalList(list.getBackingFileLocations())) + } + + def doCleanExternalList(paths: Iterable[String]): Unit = { + paths.map(path => new File(path)).foreach(f => { + if (f.exists()) { + val isDeleted = f.delete() + if (!isDeleted) { + logWarning(s"Failed to delete ${f.getAbsolutePath} backing ExternalList") + } + } + }) + } + + override protected def handleCleanupForSpecificTask(task: CleanupTask): Unit = { + task match { + case CleanExternalList(paths) => doCleanExternalList(paths) + case unknown => logWarning(s"Got cleanup task that cannot be" + + s" handled by ExecutorCleaner: $unknown") + } + } + + override protected def cleanupThreadName(): String = "Executor Cleaner" +} diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index cfde27fb2e7d3..8eb6a5ed78eac 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -72,6 +72,7 @@ class SparkEnv ( val shuffleMemoryManager: ShuffleMemoryManager, val executorMemoryManager: ExecutorMemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, + val executorCleaner: ExecutorCleaner, val conf: SparkConf) extends Logging { // TODO Remove actorSystem @@ -103,6 +104,7 @@ class SparkEnv ( if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) { actorSystem.shutdown() } + executorCleaner.stop() rpcEnv.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut @@ -400,6 +402,8 @@ object SparkEnv extends Logging { } new ExecutorMemoryManager(allocator) } + val executorCleaner = new ExecutorCleaner + executorCleaner.start() val envInstance = new SparkEnv( executorId, @@ -420,6 +424,7 @@ object SparkEnv extends Logging { shuffleMemoryManager, executorMemoryManager, outputCommitCoordinator, + executorCleaner, conf) // Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index a981b63942e6d..839c62489f908 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -45,7 +45,7 @@ import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.collection.{ExternalSorter, ExternalList, SizeTrackingCompactBuffer, CompactBuffer} import org.apache.spark.util.random.StratifiedSamplingUtils /** @@ -507,12 +507,37 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // groupByKey shouldn't use map side combine because map side combine does not // reduce the amount of data shuffled and requires all map side data be inserted // into a hash table, leading to more objects in the old gen. 
- val createCombiner = (v: V) => CompactBuffer(v) - val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v - val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 - val bufs = combineByKeyWithClassTag[CompactBuffer[V]]( - createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) - bufs.asInstanceOf[RDD[(K, Iterable[V])]] + if (PairRDDFunctions.enableGroupBySpill) { + val createCombiner = (v: V) => ExternalList(v) + val mergeValue = (buf: ExternalList[V], v: V) => buf += v + val mergeCombiners = (c1: ExternalList[V], c2: ExternalList[V]) => { + c2.foreach(c => c1 += c) + c1 + } + val aggregator = new Aggregator[K, V, ExternalList[V]](createCombiner, + mergeValue, mergeCombiners) + val shuffledRdd = if (self.partitioner != partitioner) { + self.partitionBy(partitioner) + } else { + self + } + def groupOnPartition(iterator: Iterator[(K, V)]): Iterator[(K, Iterable[V])] = { + val sorter = new ExternalSorter[K, V, ExternalList[V]](aggregator = Some(aggregator)) + sorter.insertAll(iterator) + sorter.iterator.map { keyAndGroup => + (keyAndGroup._1, keyAndGroup._2.asInstanceOf[Iterable[V]]) + } + } + + shuffledRdd.mapPartitions(groupOnPartition(_), preservesPartitioning = true) + } else { + val createCombiner = (v: V) => CompactBuffer(v) + val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v + val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 + val bufs = combineByKeyWithClassTag[CompactBuffer[V]]( + createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) + bufs.asInstanceOf[RDD[(K, Iterable[V])]] + } } /** @@ -1271,4 +1296,6 @@ private[spark] object PairRDDFunctions { * basis; see SPARK-4835 for more details. */ val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) + + val enableGroupBySpill = SparkEnv.get.conf.getBoolean("spark.groupBy.spill.enabled", false) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index c5195c1143a8f..8534ff769c953 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -39,7 +39,7 @@ import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.collection.{ExternalList, ExternalListSerializer, CompactBuffer} /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. 
@@ -104,6 +104,7 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) + kryo.register(classOf[ExternalList[_]], new ExternalListSerializer[Any]()) kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) diff --git a/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala b/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala index 76dbdd8d3c44f..e0fb9e131de33 100644 --- a/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala +++ b/core/src/main/scala/org/apache/spark/util/cleanup/CleanupTasks.scala @@ -27,6 +27,7 @@ private[spark] case class CleanShuffle(shuffleId: Int) extends CleanupTask private[spark] case class CleanBroadcast(broadcastId: Long) extends CleanupTask private[spark] case class CleanAccum(accId: Long) extends CleanupTask private[spark] case class CleanCheckpoint(rddId: Int) extends CleanupTask +private[spark] case class CleanExternalList(pathsToClean: Iterable[String]) extends CleanupTask /** * A WeakReference associated with a CleanupTask. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 29c5732f5a8c1..ca3cd5c901c90 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -24,14 +24,11 @@ import scala.collection.BufferedIterator import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.google.common.io.ByteStreams - -import org.apache.spark.{Logging, SparkEnv, TaskContext} +import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.{DeserializationStream, Serializer} -import org.apache.spark.storage.{BlockId, BlockManager} +import org.apache.spark.storage.{DiskBlockObjectWriter, BlockId, BlockManager} import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator -import org.apache.spark.executor.ShuffleWriteMetrics /** * :: DeveloperApi :: @@ -69,41 +66,16 @@ class ExternalAppendOnlyMap[K, V, C]( extends Iterable[(K, C)] with Serializable with Logging - with Spillable[SizeTracker] { + with SpillableCollection[(K, C), SizeTrackingAppendOnlyMap[K, C]] { private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] - private val sparkConf = SparkEnv.get.conf - private val diskBlockManager = blockManager.diskBlockManager - - /** - * Size of object batches when reading/writing from serializers. - * - * Objects are written in batches, with each batch using its own serialization stream. This - * cuts down on the size of reference-tracking maps constructed when deserializing a stream. - * - * NOTE: Setting this too low can cause excessive copying when serializing, since some serializers - * grow internal data structures by growing + copying every time the number of objects doubles. 
- */ - private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) - - // Number of bytes spilled in total - private var _diskBytesSpilled = 0L - def diskBytesSpilled: Long = _diskBytesSpilled - - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided - private val fileBufferSize = - sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 - - // Write metrics for current spill - private var curWriteMetrics: ShuffleWriteMetrics = _ // Peak size of the in-memory map observed so far, in bytes private var _peakMemoryUsedBytes: Long = 0L def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes private val keyComparator = new HashComparator[K] - private val ser = serializer.newInstance() /** * Insert the given key and value into the map. @@ -156,70 +128,8 @@ class ExternalAppendOnlyMap[K, V, C]( insertAll(entries.iterator) } - /** - * Sort the existing contents of the in-memory map and spill them to a temporary file on disk. - */ - override protected[this] def spill(collection: SizeTracker): Unit = { - val (blockId, file) = diskBlockManager.createTempLocalBlock() - curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) - var objectsWritten = 0 - - // List of batch sizes (bytes) in the order they are written to disk - val batchSizes = new ArrayBuffer[Long] - - // Flush the disk writer's contents to disk, and update relevant variables - def flush(): Unit = { - val w = writer - writer = null - w.commitAndClose() - _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten - batchSizes.append(curWriteMetrics.shuffleBytesWritten) - objectsWritten = 0 - } - - var success = false - try { - val it = currentMap.destructiveSortedIterator(keyComparator) - while (it.hasNext) { - val kv = it.next() - writer.write(kv._1, kv._2) - objectsWritten += 1 - - if (objectsWritten == serializerBatchSize) { - flush() - curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) - } - } - if (objectsWritten > 0) { - flush() - } else if (writer != null) { - val w = writer - writer = null - w.revertPartialWritesAndClose() - } - success = true - } finally { - if (!success) { - // This code path only happens if an exception was thrown above before we set success; - // close our stuff and let the exception be thrown further - if (writer != null) { - writer.revertPartialWritesAndClose() - } - if (file.exists()) { - if (!file.delete()) { - logWarning(s"Error deleting ${file}") - } - } - } - } - - spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) - } - - /** - * Return an iterator that merges the in-memory map with the spilled maps. + /* + * Return an iterator that merges the in-memory map with the spilled MAPS. * If no spill has occurred, simply return the in-memory map's iterator. 
*/ override def iterator: Iterator[(K, C)] = { @@ -383,130 +293,38 @@ class ExternalAppendOnlyMap[K, V, C]( * An iterator that returns (K, C) pairs in sorted order from an on-disk map */ private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) - extends Iterator[(K, C)] + extends DiskIterator(file, blockId, batchSizes) { - private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1 - assert(file.length() == batchOffsets.last, - "File length is not equal to the last batch offset:\n" + - s" file length = ${file.length}\n" + - s" last batch offset = ${batchOffsets.last}\n" + - s" all batch offsets = ${batchOffsets.mkString(",")}" - ) - - private var batchIndex = 0 // Which batch we're in - private var fileStream: FileInputStream = null - - // An intermediate stream that reads from exactly one batch - // This guards against pre-fetching and other arbitrary behavior of higher level streams - private var deserializeStream = nextBatchStream() - private var nextItem: (K, C) = null - private var objectsRead = 0 - - /** - * Construct a stream that reads only from the next batch. - */ - private def nextBatchStream(): DeserializationStream = { - // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether - // we're still in a valid batch. - if (batchIndex < batchOffsets.length - 1) { - if (deserializeStream != null) { - deserializeStream.close() - fileStream.close() - deserializeStream = null - fileStream = null - } - - val start = batchOffsets(batchIndex) - fileStream = new FileInputStream(file) - fileStream.getChannel.position(start) - batchIndex += 1 - - val end = batchOffsets(batchIndex) - - assert(end >= start, "start = " + start + ", end = " + end + - ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) - - val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) - val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream) - ser.deserializeStream(compressedStream) - } else { - // No more batches left - cleanup() - null - } + override protected def readNextItemFromStream( + deserializeStream: DeserializationStream): (K, C) = { + val k = deserializeStream.readKey().asInstanceOf[K] + val v = deserializeStream.readValue().asInstanceOf[C] + (k, v) } - /** - * Return the next (K, C) pair from the deserialization stream. - * - * If the current batch is drained, construct a stream for the next batch and read from it. - * If no more pairs are left, return null. - */ - private def readNextItem(): (K, C) = { - try { - val k = deserializeStream.readKey().asInstanceOf[K] - val c = deserializeStream.readValue().asInstanceOf[C] - val item = (k, c) - objectsRead += 1 - if (objectsRead == serializerBatchSize) { - objectsRead = 0 - deserializeStream = nextBatchStream() - } - item - } catch { - case e: EOFException => - cleanup() - null - } - } + override protected def shouldCleanupFileAfterOneIteration(): Boolean = true + } - override def hasNext: Boolean = { - if (nextItem == null) { - if (deserializeStream == null) { - return false - } - nextItem = readNextItem() - } - nextItem != null - } - override def next(): (K, C) = { - val item = if (nextItem == null) readNextItem() else nextItem - if (item == null) { - throw new NoSuchElementException - } - nextItem = null - item - } + /** Convenience function to hash the given (K, C) pair by the key. 
*/ + private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1) - private def cleanup() { - batchIndex = batchOffsets.length // Prevent reading any other batch - val ds = deserializeStream - if (ds != null) { - ds.close() - deserializeStream = null - } - if (fileStream != null) { - fileStream.close() - fileStream = null - } - if (file.exists()) { - if (!file.delete()) { - logWarning(s"Error deleting ${file}") - } - } - } + override protected def getIteratorForCurrentSpillable(): Iterator[(K, C)] = { + currentMap.destructiveSortedIterator(keyComparator) + } - val context = TaskContext.get() - // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in - // a TaskContext. - if (context != null) { - context.addTaskCompletionListener(context => cleanup()) - } + override protected def writeNextObject( + c: (K, C), + writer: DiskBlockObjectWriter): Unit = { + writer.write(c._1, c._2) } - /** Convenience function to hash the given (K, C) pair by the key. */ - private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1) + override protected def recordNextSpilledPart( + file: File, + blockId: BlockId, + batchSizes: ArrayBuffer[Long]): Unit = { + spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) + } } private[spark] object ExternalAppendOnlyMap { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala new file mode 100644 index 0000000000000..f0e4fcff81420 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalList.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util.collection + +import java.io._ + +import org.apache.spark.util.TaskCompletionListener +import org.apache.spark.{TaskContext, ExecutorCleaner, SparkEnv} + +import scala.reflect.ClassTag +import scala.collection.generic.Growable +import scala.collection.mutable.ArrayBuffer + +import com.esotericsoftware.kryo.io.{Output, Input} +import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} + +import org.apache.spark.util.collection.ExternalList._ +import org.apache.spark.serializer.DeserializationStream +import org.apache.spark.storage.{DiskBlockObjectWriter, BlockId} + + +/** + * List that can spill some of its contents to disk if its contents cannot be held in memory. 
+ * Implementation is based heavily on `org.apache.spark.util.collection.ExternalAppendOnlyMap}` + */ +@SerialVersionUID(1L) +private[spark] class ExternalList[T](implicit var tag: ClassTag[T]) + extends Growable[T] + with Iterable[T] + with SpillableCollection[T, SizeTrackingCompactBuffer[T]] + with Serializable { + + // Var to allow rebuilding it during Java serialization + private var spilledLists = new ArrayBuffer[DiskListIterable] + private var currentInMemoryList = new SizeTrackingCompactBuffer[T]() + private var numItems = 0 + + // We don't know up front what files will need to be cleaned up from this list. + // So check after the task is completed, after which this ExternalList will be + // completely built. + private var context = TaskContext.get + if (context != null) { + context.addTaskCompletionListener(new ScheduleCleanExternalList(this)) + } + + override def size(): Int = numItems + + override def +=(value: T): this.type = { + currentInMemoryList += value + if (maybeSpill(currentInMemoryList, currentInMemoryList.estimateSize())) { + currentInMemoryList = new SizeTrackingCompactBuffer + } + numItems += 1 + this + } + + override def clear(): Unit = { + spilledLists.foreach(_.deleteBackingFile()) + spilledLists.clear() + currentInMemoryList = new SizeTrackingCompactBuffer[T]() + } + + def getBackingFileLocations(): Iterable[String] = { + val locations = new ArrayBuffer[String] + for (diskList <- spilledLists) { + locations.append(diskList.backingFilePath()) + } + return locations + } + + def registerForCleanup(): Unit = { + if (spilledLists.size > 0) { + executorCleaner.registerExternalListForCleanup(this) + } + } + + override def iterator: Iterator[T] = { + val myIt = currentInMemoryList.iterator + val allIts = spilledLists.map(_.iterator) ++ Seq(myIt) + allIts.foldLeft(Iterator[T]())(_ ++ _) + } + + private class DiskListIterable(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) + extends Iterable[T] { + override def iterator: Iterator[T] = { + new DiskListIterator(file, blockId, batchSizes) + } + def deleteBackingFile(): Unit = { + if (file.exists()) { + file.delete() + } + } + def backingFilePath(): String = file.getAbsolutePath() + } + + private class DiskListIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) + extends DiskIterator(file, blockId, batchSizes) { + override protected def readNextItemFromStream(deserializeStream: DeserializationStream): T = { + deserializeStream.readKey[Int]() + deserializeStream.readValue[T]() + } + + // Need to be able to iterate multiple times, so don't clean up the file every time + override protected def shouldCleanupFileAfterOneIteration(): Boolean = false + } + + @throws(classOf[IOException]) + private def writeObject(stream: ObjectOutputStream): Unit = { + stream.writeObject(tag) + stream.writeInt(this.size) + val it = this.iterator + while (it.hasNext) { + stream.writeObject(it.next) + } + } + + @throws(classOf[IOException]) + private def readObject(stream: ObjectInputStream): Unit = { + tag = stream.readObject().asInstanceOf[ClassTag[T]] + val listSize = stream.readInt() + spilledLists = new ArrayBuffer[DiskListIterable] + currentInMemoryList = new SizeTrackingCompactBuffer[T] + for(i <- 0L until listSize) { + val newItem = stream.readObject().asInstanceOf[T] + this.+=(newItem) + } + // Upon serialization, the context might have changed. So we can't just hold a single context, + // but we must retrieving the current context every time. 
+ // Notice that in Kryo serialization this object is constructed from scratch + // and thus will look for the current TaskContext that way. + context = TaskContext.get() + if (context != null) { + context.addTaskCompletionListener(new ScheduleCleanExternalList(this)) + } + } + + override protected def getIteratorForCurrentSpillable(): Iterator[T] = { + currentInMemoryList.iterator + } + + override protected def recordNextSpilledPart( + file: File, + blockId: BlockId, + batchSizes: ArrayBuffer[Long]): Unit = { + spilledLists += new DiskListIterable(file, blockId, batchSizes) + } + override protected def writeNextObject(c: T, writer: DiskBlockObjectWriter): Unit = { + writer.write(0, c) + } +} + +/** + * Companion object for constants and singleton-references that we don't want to lose when + * Java-serializing + */ +private[spark] object ExternalList { + + private class ScheduleCleanExternalList(private var list: ExternalList[_]) + extends TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = { + if (list != null) { + executorCleaner.registerExternalListForCleanup(list) + // Release reference to allow GC to clean it up + list = null + } + } + } + + def apply[T: ClassTag](): ExternalList[T] = new ExternalList[T] + + def apply[T: ClassTag](value: T): ExternalList[T] = { + val buf = new ExternalList[T] + buf += value + buf + } + + private val executorCleaner: ExecutorCleaner = SparkEnv.get.executorCleaner +} + +private[spark] class ExternalListSerializer[T: ClassTag] extends KSerializer[ExternalList[T]] { + override def write(kryo: Kryo, output: Output, list: ExternalList[T]): Unit = { + output.writeInt(list.size) + val it = list.iterator + while (it.hasNext) { + kryo.writeClassAndObject(output, it.next()) + } + } + + override def read(kryo: Kryo, input: Input, clazz: Class[ExternalList[T]]): ExternalList[T] = { + val listToRead = new ExternalList[T] + val listSize = input.readInt() + for (i <- 0L until listSize) { + val newItem = kryo.readClassAndObject(input).asInstanceOf[T] + listToRead += newItem + } + listToRead + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala new file mode 100644 index 0000000000000..d923e9a9e0bd1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingCompactBuffer.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util.collection + +import scala.reflect.ClassTag + +/** + * CompactBuffer that keeps track of its size via SizeTracker. 
+ */ +private[spark] class SizeTrackingCompactBuffer[T: ClassTag] extends CompactBuffer[T] + with SizeTracker { + + override def +=(t: T): SizeTrackingCompactBuffer[T] = { + super.+=(t) + super.afterUpdate() + this + } + + override def ++=(t: TraversableOnce[T]): SizeTrackingCompactBuffer[T] = { + super.++=(t) + super.afterUpdate() + this + } +} + +private[spark] object SizeTrackingCompactBuffer { + def apply[T: ClassTag](): SizeTrackingCompactBuffer[T] = new SizeTrackingCompactBuffer[T] + + def apply[T: ClassTag](value: T): SizeTrackingCompactBuffer[T] = { + val buf = new SizeTrackingCompactBuffer[T] + buf += value + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 747ecf075a397..a710d618f3d23 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -19,6 +19,7 @@ package org.apache.spark.util.collection import org.apache.spark.Logging import org.apache.spark.SparkEnv +import org.apache.spark.util.collection.Spillable._ /** * Spills contents of an in-memory collection to disk when the memory threshold @@ -39,14 +40,6 @@ private[spark] trait Spillable[C] extends Logging { // It's used for checking spilling frequency protected def addElementsRead(): Unit = { _elementsRead += 1 } - // Memory manager that can be used to acquire/release memory - private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager - - // Initial threshold for the size of a collection before we start tracking its memory usage - // Exposed for testing - private[this] val initialMemoryThreshold: Long = - SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) - // Threshold for this collection's size in bytes before we start tracking its memory usage // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0 private[this] var myMemoryThreshold = initialMemoryThreshold @@ -117,4 +110,15 @@ private[spark] trait Spillable[C] extends Logging { .format(threadId, org.apache.spark.util.Utils.bytesToString(size), _spillCount, if (_spillCount > 1) "s" else "")) } + +} + +private object Spillable { + // Memory manager that can be used to acquire/release memory + protected val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager + + // Initial threshold for the size of a collection before we start tracking its memory usage + // Exposed for testing + protected val initialMemoryThreshold: Long = + SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024) } diff --git a/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala new file mode 100644 index 0000000000000..7952e59da69fa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SpillableCollection.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io.{EOFException, BufferedInputStream, FileInputStream, File} + +import scala.collection.mutable.ArrayBuffer + +import com.google.common.io.ByteStreams + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.serializer.{DeserializationStream, Serializer} +import org.apache.spark.storage.{DiskBlockManager, BlockId, DiskBlockObjectWriter, BlockManager} +import org.apache.spark.util.collection.SpillableCollection._ + +/** + * + * Collection that can spill to disk. Takes type parameters T, the iterable type, and + * C, the type of the elements returned by T's iterator. + */ +private[spark] trait SpillableCollection[C, T <: Iterable[C]] extends Spillable[T] { + // Write metrics for current spill + private var curWriteMetrics: ShuffleWriteMetrics = _ + // Number of bytes spilled in total + protected var _diskBytesSpilled = 0L + private lazy val ser = serializer.newInstance() + + def diskBytesSpilled: Long = _diskBytesSpilled + + override protected final def spill(collection: T): Unit = { + val (blockId, file) = diskBlockManager.createTempLocalBlock() + curWriteMetrics = new ShuffleWriteMetrics() + var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + var objectsWritten = 0 + + // List of batch sizes (bytes) in the order they are written to disk + val batchSizes = new ArrayBuffer[Long] + + // Flush the disk writer's contents to disk, and update relevant variables + def flush(): Unit = { + val w = writer + writer = null + w.commitAndClose() + _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten + batchSizes.append(curWriteMetrics.shuffleBytesWritten) + objectsWritten = 0 + } + + var success = false + try { + val it = getIteratorForCurrentSpillable() + while (it.hasNext) { + val kv = it.next() + writeNextObject(kv, writer) + objectsWritten += 1 + + if (objectsWritten == serializerBatchSize) { + flush() + curWriteMetrics = new ShuffleWriteMetrics() + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + } + } + if (objectsWritten > 0) { + flush() + } else if (writer != null) { + val w = writer + writer = null + w.revertPartialWritesAndClose() + } + success = true + } finally { + if (!success) { + // This code path only happens if an exception was thrown above before we set success; + // close our stuff and let the exception be thrown further + if (writer != null) { + writer.revertPartialWritesAndClose() + } + if (file.exists()) { + if (!file.delete()) { + logWarning(s"Error deleting ${file}") + } + } + } + } + recordNextSpilledPart(file, blockId, batchSizes) + } + + + protected def getIteratorForCurrentSpillable(): Iterator[C] + protected def writeNextObject(c: C, writer: DiskBlockObjectWriter): Unit + protected def recordNextSpilledPart(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) + + /** + * Iterator backed by elements from batches on disk. 
+ */ + protected abstract class DiskIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) + extends Iterator[C] { + private val batchOffsets = batchSizes.scanLeft(0L)(_ + _) // Size will be batchSize.length + 1 + assert(file.length() == batchOffsets.last, + "File length is not equal to the last batch offset:\n" + + s" file length = ${file.length}\n" + + s" last batch offset = ${batchOffsets.last}\n" + + s" all batch offsets = ${batchOffsets.mkString(",")}" + ) + + private var batchIndex = 0 // Which batch we're in + private var fileStream: FileInputStream = null + + // An intermediate stream that reads from exactly one batch + // This guards against pre-fetching and other arbitrary behavior of higher level streams + private var deserializeStream = nextBatchStream() + private var nextItem: Option[C] = None + private var objectsRead = 0 + + /** + * Construct a stream that reads only from the next batch. + */ + protected def nextBatchStream(): DeserializationStream = { + // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether + // we're still in a valid batch. + if (batchIndex < batchOffsets.length - 1) { + if (deserializeStream != null) { + deserializeStream.close() + fileStream.close() + deserializeStream = null + fileStream = null + } + + val start = batchOffsets(batchIndex) + fileStream = new FileInputStream(file) + fileStream.getChannel.position(start) + batchIndex += 1 + + val end = batchOffsets(batchIndex) + + assert(end >= start, "start = " + start + ", end = " + end + + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) + + val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream) + ser.deserializeStream(compressedStream) + } else { + // No more batches left + cleanup() + null + } + } + + /** + * Return the next item from the deserialization stream. + * + * If the current batch is drained, construct a stream for the next batch and read from it. + * If no more items are left, return null. 
+ */ + protected def readNextItem(): Option[C] = { + try { + val item = readNextItemFromStream(deserializeStream) + objectsRead += 1 + if (objectsRead == serializerBatchSize) { + objectsRead = 0 + deserializeStream = nextBatchStream() + } + Some(item) + } catch { + case e: EOFException => + cleanup() + None + } + } + + private def cleanup() { + batchIndex = batchOffsets.length // Prevent reading any other batch + val ds = deserializeStream + deserializeStream = null + if (ds != null) { + ds.close() + } + val fs = fileStream + fileStream = null + if (fs != null) { + fs.close() + } + if (shouldCleanupFileAfterOneIteration()) { + if (file.exists()) { + if (!file.delete()) { + logWarning(s"Error deleting ${file}") + } + } + } + } + + override def hasNext(): Boolean = { + if (!nextItem.isDefined) { + if (deserializeStream == null) { + return false + } + nextItem = readNextItem() + } + nextItem.isDefined + } + + override def next(): C = { + if (!hasNext()) { + throw new NoSuchElementException() + } + val nextValue = nextItem.get + nextItem = None + nextValue + } + + protected def readNextItemFromStream(deserializeStream: DeserializationStream): C + protected def shouldCleanupFileAfterOneIteration(): Boolean + } +} + +private object SpillableCollection { + private def sparkConf(): SparkConf = SparkEnv.get.conf + private def blockManager(): BlockManager = SparkEnv.get.blockManager + private def diskBlockManager(): DiskBlockManager = blockManager.diskBlockManager + private def fileBufferSize(): Int = + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 + /** + * Size of object batches when reading/writing from serializers. + * + * Objects are written in batches, with each batch using its own serialization stream. This + * cuts down on the size of reference-tracking maps constructed when deserializing a stream. + * + * NOTE: Setting this too low can cause excessive copying when serializing, since some serializers + * grow internal data structures by growing + copying every time the number of objects doubles. + */ + private def serializerBatchSize(): Long = + sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) + + private def serializer(): Serializer = SparkEnv.get.serializer +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala new file mode 100644 index 0000000000000..9a7cccacb234b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.util.collection + +import java.io.File +import java.lang.ref.WeakReference + +import scala.language.existentials +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.serializer.{KryoSerializer, JavaSerializer, SerializerInstance} +import org.apache.spark.util.collection.ExternalListSuite._ +import org.apache.spark.unsafe.memory.TaskMemoryManager + +import org.junit.Assert.{assertEquals, assertTrue, assertFalse} +import org.mockito.Mockito.mock +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +class ExternalListSuite extends SparkFunSuite with SharedSparkContext { + + override def beforeAll() { + conf.set("spark.kryoserializer.buffer.max", "2046m") + conf.set("spark.shuffle.spill.initialMemoryThreshold", "1") + conf.set("spark.shuffle.spill.batchSize", "500") + conf.set("spark.shuffle.memoryFraction", "0.04") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + conf.set("spark.task.maxFailures", "1") + conf.setAppName("test") + super.beforeAll() + } + + test("Serializing and deserializing a spilled list should produce the same values") { + testSerialization(new KryoSerializer(conf).newInstance(), 4500000) + testSerialization(new JavaSerializer(conf).newInstance(), 3000) + } + + test("Lists that are cached should be accessible twice, but when unpersisted are cleaned up.") { + val rawLargeRdd = sc.parallelize(1 to totalRddSize) + val groupedRdd = rawLargeRdd.map(x => (x % numBuckets, x)).groupByKey + val cachedRdd = groupedRdd.cache() + cachedRdd.foreach(validateList(totalRddSize, numBuckets, _)) + runGC() + // GC on the Cached RDD shouldn't trigger the cleanup + cachedRdd.foreach(validateList(totalRddSize, numBuckets, _)) + def fileLocationsFromIterable(pair: (_, Iterable[Int])): Iterable[String] = { + pair._2.asInstanceOf[ExternalList[Int]].getBackingFileLocations() + } + val filePaths = cachedRdd.map(fileLocationsFromIterable).collect + filePaths.foreach(paths => { + paths.foreach(f => assertTrue(new File(f).exists())) + }) + cachedRdd.unpersist(true) + runGC() + checkFilesEventuallyRemoved(filePaths) + cachedRdd.foreach(validateList(totalRddSize, numBuckets, _)) + } + + private def checkFilesEventuallyRemoved(filePaths: Array[Iterable[String]]) { + eventually(timeout(40000 millis), interval(100 millis)) { + filePaths.foreach(paths => { + paths.foreach(f => assertFalse(new File(f).exists())) + }) + } + } + + /** Run GC and make sure it actually has run */ + private def runGC() { + val weakRef = new WeakReference(new Object()) + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. 
+ // Wait until a weak reference object has been GCed + while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + Thread.sleep(200) + } + } + + private def testSerialization[T: ClassTag]( + serializer: SerializerInstance, + numItems: Int): Unit = { + val list = new ExternalList[Int] + // Test big list for Kryo because it's fast enough to handle it + // and we want to test the case where the list would spill to disk + for (i <- 0 to numItems) { + list += i + } + createAndSetFakeTaskContext() + val bytes = serializer.serialize(list) + var readList = serializer.deserialize(bytes).asInstanceOf[ExternalList[Int]] + val originalIt = list.iterator + var readIt = readList.iterator + while (originalIt.hasNext) { + assertTrue(originalIt.next == readIt.next) + } + assertFalse (readIt.hasNext) + val filePaths = readList.getBackingFileLocations() + readList = null + readIt = null + taskContext.markTaskCompleted() + runGC() + eventually(timeout(40000 millis), interval(100 millis)) { + filePaths.foreach(path => assertFalse(new File(path).exists())) + } + TaskContext.unset() + } +} + +object ExternalListSuite { + var taskContext: TaskContextImpl = null + val totalRddSize = 2000000 + val numBuckets = 5 + + private def createAndSetFakeTaskContext(): Unit = { + taskContext = new TaskContextImpl(0, 0, 0L, 0, mock(classOf[TaskMemoryManager]), + SparkEnv.get.metricsSystem, Seq.empty[Accumulator[Long]]) + TaskContext.setTaskContext(taskContext) + } + + private def validateList(totalRddSize: Int, numBuckets: Int, kv: (Int, Iterable[Int])): Unit = { + var numItems = 0 + for (valsInBucket <- kv._2) { + numItems += 1 + // Can't use scala assertions because including assert statements makes closures + // not serializable. + assertEquals(s"Value $valsInBucket should not be" + + s" in bucket ${kv._1}", valsInBucket % numBuckets, kv._1) + } + assertEquals(s"Number of items in bucket ${kv._1} is incorrect.", + totalRddSize / numBuckets, numItems) + } +} + + + From 860811d083335d825da088361eee8350d980088d Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 7 Oct 2015 14:50:01 -0400 Subject: [PATCH 3/3] Turn on group-by-key spill for ExternalListSuite --- .../org/apache/spark/util/collection/ExternalListSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala index 9a7cccacb234b..1b923bb2bdde0 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalListSuite.scala @@ -36,10 +36,11 @@ class ExternalListSuite extends SparkFunSuite with SharedSparkContext { override def beforeAll() { conf.set("spark.kryoserializer.buffer.max", "2046m") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") conf.set("spark.shuffle.spill.initialMemoryThreshold", "1") conf.set("spark.shuffle.spill.batchSize", "500") conf.set("spark.shuffle.memoryFraction", "0.04") - conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + conf.set("spark.groupBy.spill.enabled", "true") conf.set("spark.task.maxFailures", "1") conf.setAppName("test") super.beforeAll()
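Reviewer note: with the usual sbt-based build, something along these lines should run the suite added in this series (the exact invocation may differ depending on the build setup):

    build/sbt "core/test-only org.apache.spark.util.collection.ExternalListSuite"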