From 7610f2f7613050e5b32eb9314245d79c0dac7b94 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Wed, 13 May 2015 13:21:18 -0700
Subject: [PATCH] Add tests for proper cleanup of shuffle data.

UnsafeShuffleManager now records which shuffles used the new write path
(and how many map tasks each had) and which fell back to sort-based
shuffle, so that unregisterShuffle() can delete the data and index files
written by either path. Two new tests in UnsafeShuffleSuite verify that
the files created by a shuffle are actually removed.
---
 .../shuffle/sort/SortShuffleManager.scala     |  2 +-
 .../shuffle/unsafe/UnsafeShuffleManager.scala | 26 +++++--
 .../shuffle/unsafe/UnsafeShuffleSuite.scala   | 72 ++++++++++++++++++-
 3 files changed, 92 insertions(+), 8 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 15842941daaab..d7fab351ca3b8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -72,7 +72,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
     true
   }
 
-  override def shuffleBlockResolver: IndexShuffleBlockResolver = {
+  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
     indexShuffleBlockResolver
   }
 
diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
index ce684fbe59d79..f2bfef376d3ca 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.shuffle.unsafe
 
+import java.util.Collections
+import java.util.concurrent.ConcurrentHashMap
+
 import org.apache.spark._
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle._
@@ -25,7 +28,7 @@ import org.apache.spark.shuffle.sort.SortShuffleManager
 /**
  * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle.
  */
-private class UnsafeShuffleHandle[K, V](
+private[spark] class UnsafeShuffleHandle[K, V](
     shuffleId: Int,
     numMaps: Int,
     dependency: ShuffleDependency[K, V, V])
@@ -121,8 +124,10 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage
       "manager; its optimized shuffles will continue to spill to disk when necessary.")
   }
 
-  private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+  private[this] val shufflesThatFellBackToSortShuffle =
+    Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
+  private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()
 
   /**
    * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
@@ -158,8 +163,8 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage
       context: TaskContext): ShuffleWriter[K, V] = {
     handle match {
       case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] =>
+        numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
         val env = SparkEnv.get
-        // TODO: do we need to do anything to register the shuffle here?
         new UnsafeShuffleWriter(
           env.blockManager,
           shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
@@ -170,17 +175,26 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage
           context,
           env.conf)
       case other =>
+        shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
         sortShuffleManager.getWriter(handle, mapId, context)
     }
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
   override def unregisterShuffle(shuffleId: Int): Boolean = {
-    // TODO: need to do something here for our unsafe path
-    sortShuffleManager.unregisterShuffle(shuffleId)
+    if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
+      sortShuffleManager.unregisterShuffle(shuffleId)
+    } else {
+      Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
+        (0 until numMaps).foreach { mapId =>
+          shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+        }
+      }
+      true
+    }
   }
 
-  override def shuffleBlockResolver: ShuffleBlockResolver = {
+  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
     sortShuffleManager.shuffleBlockResolver
   }
 
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
index e68261a730d3a..64569f1c60927 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
@@ -17,9 +17,17 @@
 
 package org.apache.spark.shuffle.unsafe
 
-import org.apache.spark.ShuffleSuite
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
 class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
 
   // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
@@ -30,4 +38,66 @@ class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
     // shuffle records.
conf.set("spark.shuffle.memoryFraction", "0.5") } + + test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") { + val tmpDir = Utils.createTempDir() + try { + val myConf = conf.clone() + .set("spark.local.dir", tmpDir.getAbsolutePath) + sc = new SparkContext("local", "test", myConf) + // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new KryoSerializer(myConf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) + def getAllFiles = + FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } finally { + Utils.deleteRecursively(tmpDir) + } + } + + test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") { + val tmpDir = Utils.createTempDir() + try { + val myConf = conf.clone() + .set("spark.local.dir", tmpDir.getAbsolutePath) + sc = new SparkContext("local", "test", myConf) + // Create a shuffled RDD and verify that it will actually use the old SortShuffle path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new JavaSerializer(myConf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) + def getAllFiles = + FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } finally { + Utils.deleteRecursively(tmpDir) + } + } }