Commit 78f2af5

aarondav authored and mateiz committed
SPARK-2791: Fix committing, reverting and state tracking in shuffle file consolidation
All changes from this PR are by mridulm and are drawn from his work in #1609. This patch is intended to fix all major issues related to shuffle file consolidation that mridulm found, while minimizing changes to the code, with the hope that it may be more easily merged into 1.1.

This patch is **not** intended as a replacement for #1609, which provides many additional benefits, including fixes to ExternalAppendOnlyMap, improvements to DiskBlockObjectWriter's API, and several new unit tests. If it is feasible to merge #1609 for the 1.1 deadline, that is a preferable option.

Author: Aaron Davidson <aaron@databricks.com>

Closes #1678 from aarondav/consol and squashes the following commits:

53b3f6d [Aaron Davidson] Correct behavior when writing unopened file
701d045 [Aaron Davidson] Rebase with sort-based shuffle
9160149 [Aaron Davidson] SPARK-2532: Minimal shuffle consolidation fixes
1 parent b270309 commit 78f2af5

File tree: 8 files changed (+146, -52 lines)

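For orientation before the per-file diffs, here is a minimal sketch of the writer contract this patch moves the shuffle code onto. These are not the actual Spark classes; the simplified names and signatures below are assumptions for illustration. The essential change: commit() and close() are merged into commitAndClose(), revertPartialWrites() becomes revertPartialWritesAndClose(), and the segment and byte-count accessors are only meaningful after a successful commit.

    import java.io.File

    // Simplified stand-ins for the types in org.apache.spark.storage (illustrative only).
    case class SimpleFileSegment(file: File, offset: Long, length: Long)

    abstract class SimpleBlockObjectWriter {
      /** Writes one object to the underlying stream. */
      def write(obj: Any): Unit

      /** Flushes partial writes, commits them as one atomic block, and closes the writer. */
      def commitAndClose(): Unit

      /** Reverts unflushed writes and closes; should not throw even if truncation fails. */
      def revertPartialWritesAndClose(): Unit

      /** The committed segment; only valid after commitAndClose() has returned. */
      def fileSegment(): SimpleFileSegment

      /** Bytes committed; only valid after a successful commitAndClose(). */
      def bytesWritten: Long
    }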

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala

Lines changed: 7 additions & 7 deletions

@@ -65,23 +65,25 @@ private[spark] class HashShuffleWriter[K, V](
   }
 
   /** Close this writer, passing along whether the map completed */
-  override def stop(success: Boolean): Option[MapStatus] = {
+  override def stop(initiallySuccess: Boolean): Option[MapStatus] = {
+    var success = initiallySuccess
     try {
       if (stopping) {
         return None
       }
       stopping = true
       if (success) {
         try {
-          return Some(commitWritesAndBuildStatus())
+          Some(commitWritesAndBuildStatus())
         } catch {
           case e: Exception =>
+            success = false
             revertWrites()
             throw e
         }
       } else {
         revertWrites()
-        return None
+        None
       }
     } finally {
       // Release the writers back to the shuffle block manager.
@@ -100,8 +102,7 @@ private[spark] class HashShuffleWriter[K, V](
     var totalBytes = 0L
     var totalTime = 0L
     val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter =>
-      writer.commit()
-      writer.close()
+      writer.commitAndClose()
       val size = writer.fileSegment().length
       totalBytes += size
       totalTime += writer.timeWriting()
@@ -120,8 +121,7 @@ private[spark] class HashShuffleWriter[K, V](
   private def revertWrites(): Unit = {
     if (shuffle != null && shuffle.writers != null) {
       for (writer <- shuffle.writers) {
-        writer.revertPartialWrites()
-        writer.close()
+        writer.revertPartialWritesAndClose()
       }
     }
   }
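The state-tracking half of the fix is visible in the stop() diff above: if commitWritesAndBuildStatus() throws, success is reset to false before the finally block runs, so the writers are released with the real outcome. The sketch below spells out that shape with stub members standing in for HashShuffleWriter's (the stub names are assumptions, not Spark's API).

    object StopSketch {
      // Stubs standing in for HashShuffleWriter's private members (illustrative only).
      private def commitWritesAndBuildStatus(): String = "map-status"
      private def revertWrites(): Unit = ()
      private def releaseWriters(success: Boolean): Unit = ()

      def stop(initiallySuccess: Boolean): Option[String] = {
        var success = initiallySuccess
        try {
          if (success) {
            try {
              Some(commitWritesAndBuildStatus())
            } catch {
              case e: Exception =>
                success = false      // ensure the finally block sees the failure
                revertWrites()
                throw e
            }
          } else {
            revertWrites()
            None
          }
        } finally {
          releaseWriters(success)    // writers go back with the final outcome
        }
      }
    }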

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala

Lines changed: 1 addition & 2 deletions

@@ -94,8 +94,7 @@ private[spark] class SortShuffleWriter[K, V, C](
         for (elem <- elements) {
           writer.write(elem)
         }
-        writer.commit()
-        writer.close()
+        writer.commitAndClose()
         val segment = writer.fileSegment()
         offsets(id + 1) = segment.offset + segment.length
         lengths(id) = segment.length

core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 31 additions & 22 deletions

@@ -39,16 +39,16 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
   def isOpen: Boolean
 
   /**
-   * Flush the partial writes and commit them as a single atomic block. Return the
-   * number of bytes written for this commit.
+   * Flush the partial writes and commit them as a single atomic block.
    */
-  def commit(): Long
+  def commitAndClose(): Unit
 
   /**
    * Reverts writes that haven't been flushed yet. Callers should invoke this function
-   * when there are runtime exceptions.
+   * when there are runtime exceptions. This method will not throw, though it may be
+   * unsuccessful in truncating written data.
    */
-  def revertPartialWrites()
+  def revertPartialWritesAndClose()
 
   /**
    * Writes an object.
@@ -57,6 +57,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
 
   /**
    * Returns the file segment of committed data that this Writer has written.
+   * This is only valid after commitAndClose() has been called.
   */
   def fileSegment(): FileSegment
 
@@ -108,15 +109,14 @@ private[spark] class DiskBlockObjectWriter(
   private var ts: TimeTrackingOutputStream = null
   private var objOut: SerializationStream = null
   private val initialPosition = file.length()
-  private var lastValidPosition = initialPosition
+  private var finalPosition: Long = -1
   private var initialized = false
   private var _timeWriting = 0L
 
   override def open(): BlockObjectWriter = {
     fos = new FileOutputStream(file, true)
     ts = new TimeTrackingOutputStream(fos)
     channel = fos.getChannel()
-    lastValidPosition = initialPosition
     bs = compressStream(new BufferedOutputStream(ts, bufferSize))
     objOut = serializer.newInstance().serializeStream(bs)
     initialized = true
@@ -147,28 +147,36 @@ private[spark] class DiskBlockObjectWriter(
 
   override def isOpen: Boolean = objOut != null
 
-  override def commit(): Long = {
+  override def commitAndClose(): Unit = {
     if (initialized) {
       // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
       // serializer stream and the lower level stream.
       objOut.flush()
       bs.flush()
-      val prevPos = lastValidPosition
-      lastValidPosition = channel.position()
-      lastValidPosition - prevPos
-    } else {
-      // lastValidPosition is zero if stream is uninitialized
-      lastValidPosition
+      close()
     }
+    finalPosition = file.length()
   }
 
-  override def revertPartialWrites() {
-    if (initialized) {
-      // Discard current writes. We do this by flushing the outstanding writes and
-      // truncate the file to the last valid position.
-      objOut.flush()
-      bs.flush()
-      channel.truncate(lastValidPosition)
+  // Discard current writes. We do this by flushing the outstanding writes and then
+  // truncating the file to its initial position.
+  override def revertPartialWritesAndClose() {
+    try {
+      if (initialized) {
+        objOut.flush()
+        bs.flush()
+        close()
+      }
+
+      val truncateStream = new FileOutputStream(file, true)
+      try {
+        truncateStream.getChannel.truncate(initialPosition)
+      } finally {
+        truncateStream.close()
+      }
+    } catch {
+      case e: Exception =>
+        logError("Uncaught exception while reverting partial writes to file " + file, e)
     }
   }
 
@@ -188,6 +196,7 @@ private[spark] class DiskBlockObjectWriter(
 
   // Only valid if called after commit()
   override def bytesWritten: Long = {
-    lastValidPosition - initialPosition
+    assert(finalPosition != -1, "bytesWritten is only valid after successful commit()")
+    finalPosition - initialPosition
   }
 }
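To make the caller discipline concrete, the following is a hedged sketch of how writers are expected to be used after this change: commit on success, revert on any failure, and read bytesWritten only after commitAndClose() has returned. The trait and helper names here are simplifications, not Spark's API.

    trait SimpleWriter {
      def write(obj: Any): Unit
      def commitAndClose(): Unit
      def revertPartialWritesAndClose(): Unit
      def bytesWritten: Long
    }

    object WriterUsageSketch {
      // Writes all records through one writer, committing on success and reverting on failure.
      def writeAll(writer: SimpleWriter, records: Iterator[Any]): Long = {
        try {
          records.foreach(writer.write)
          writer.commitAndClose()
          writer.bytesWritten                     // defined only after a successful commit
        } catch {
          case e: Exception =>
            writer.revertPartialWritesAndClose()  // documented above as non-throwing
            throw e
        }
      }
    }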

core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala

Lines changed: 15 additions & 13 deletions

@@ -144,7 +144,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
       if (consolidateShuffleFiles) {
         if (success) {
           val offsets = writers.map(_.fileSegment().offset)
-          fileGroup.recordMapOutput(mapId, offsets)
+          val lengths = writers.map(_.fileSegment().length)
+          fileGroup.recordMapOutput(mapId, offsets, lengths)
         }
         recycleFileGroup(fileGroup)
       } else {
@@ -247,47 +248,48 @@ object ShuffleBlockManager {
    * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
    */
   private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) {
+    private var numBlocks: Int = 0
+
     /**
      * Stores the absolute index of each mapId in the files of this group. For instance,
      * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
      */
     private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
 
     /**
-     * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
-     * This ordering allows us to compute block lengths by examining the following block offset.
+     * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by
+     * position in the file.
      * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every
      * reducer.
      */
     private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) {
       new PrimitiveVector[Long]()
     }
-
-    def numBlocks = mapIdToIndex.size
+    private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) {
+      new PrimitiveVector[Long]()
+    }
 
     def apply(bucketId: Int) = files(bucketId)
 
-    def recordMapOutput(mapId: Int, offsets: Array[Long]) {
+    def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) {
+      assert(offsets.length == lengths.length)
       mapIdToIndex(mapId) = numBlocks
+      numBlocks += 1
       for (i <- 0 until offsets.length) {
         blockOffsetsByReducer(i) += offsets(i)
+        blockLengthsByReducer(i) += lengths(i)
       }
     }
 
     /** Returns the FileSegment associated with the given map task, or None if no entry exists. */
     def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = {
       val file = files(reducerId)
       val blockOffsets = blockOffsetsByReducer(reducerId)
+      val blockLengths = blockLengthsByReducer(reducerId)
       val index = mapIdToIndex.getOrElse(mapId, -1)
       if (index >= 0) {
         val offset = blockOffsets(index)
-        val length =
-          if (index + 1 < numBlocks) {
-            blockOffsets(index + 1) - offset
-          } else {
-            file.length() - offset
-          }
-        assert(length >= 0)
+        val length = blockLengths(index)
         Some(new FileSegment(file, offset, length))
       } else {
         None
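The ShuffleFileGroup change above removes length inference entirely: previously a block's length was derived as the next block's offset minus its own (or the file length minus the offset for the last block), which goes wrong once another map task appends to the same consolidated file between a commit and a later lookup. Below is a small sketch of the recording idea, using a plain ArrayBuffer instead of Spark's PrimitiveVector (the class and method names are illustrative assumptions).

    import scala.collection.mutable.ArrayBuffer

    // Per-reducer index of committed blocks: each entry stores (offset, length) explicitly,
    // so a lookup is unaffected by any data appended to the file afterwards.
    class ReducerFileIndexSketch {
      private val offsets = new ArrayBuffer[Long]()
      private val lengths = new ArrayBuffer[Long]()

      def record(offset: Long, length: Long): Unit = {
        offsets += offset
        lengths += length
      }

      /** Returns (offset, length) for the block at `index`. */
      def segment(index: Int): (Long, Long) = (offsets(index), lengths(index))
    }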

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 1 addition & 1 deletion

@@ -199,7 +199,7 @@ class ExternalAppendOnlyMap[K, V, C](
 
     // Flush the disk writer's contents to disk, and update relevant variables
     def flush() = {
-      writer.commit()
+      writer.commitAndClose()
       val bytesWritten = writer.bytesWritten
       batchSizes.append(bytesWritten)
       _diskBytesSpilled += bytesWritten

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 3 additions & 3 deletions

@@ -270,9 +270,10 @@ private[spark] class ExternalSorter[K, V, C](
     // How many elements we have in each partition
     val elementsPerPartition = new Array[Long](numPartitions)
 
-    // Flush the disk writer's contents to disk, and update relevant variables
+    // Flush the disk writer's contents to disk, and update relevant variables.
+    // The writer is closed at the end of this process, and cannot be reused.
     def flush() = {
-      writer.commit()
+      writer.commitAndClose()
       val bytesWritten = writer.bytesWritten
       batchSizes.append(bytesWritten)
       _diskBytesSpilled += bytesWritten
@@ -293,7 +294,6 @@ private[spark] class ExternalSorter[K, V, C](
 
       if (objectsWritten == serializerBatchSize) {
         flush()
-        writer.close()
         writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
       }
     }
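Since commitAndClose() closes the writer, the spill loop above no longer calls close() separately; it requests a fresh writer for the next batch instead of reusing the old one. Below is a minimal sketch of that pattern with stub names standing in for ExternalSorter's members (all names are illustrative assumptions).

    import scala.collection.mutable.ArrayBuffer

    trait BatchWriterSketch {
      def write(obj: Any): Unit
      def commitAndClose(): Unit
      def bytesWritten: Long
    }

    object SpillLoopSketch {
      // Writes records in fixed-size batches; each flushed batch commits and closes its writer,
      // and a new writer is obtained for the next batch rather than reusing the closed one.
      def spill(records: Iterator[Any], batchSize: Int,
                newWriter: () => BatchWriterSketch): ArrayBuffer[Long] = {
        val batchSizes = new ArrayBuffer[Long]()
        var writer = newWriter()
        var objectsWritten = 0
        for (rec <- records) {
          writer.write(rec)
          objectsWritten += 1
          if (objectsWritten == batchSize) {
            writer.commitAndClose()
            batchSizes += writer.bytesWritten   // valid only after a successful commitAndClose()
            writer = newWriter()                // the closed writer cannot be reused
            objectsWritten = 0
          }
        }
        if (objectsWritten > 0) {
          writer.commitAndClose()
          batchSizes += writer.bytesWritten
        }
        batchSizes
      }
    }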

core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala

Lines changed: 86 additions & 1 deletion

@@ -22,11 +22,14 @@ import java.io.{File, FileWriter}
 import scala.collection.mutable
 import scala.language.reflectiveCalls
 
+import akka.actor.Props
 import com.google.common.io.Files
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
 
 import org.apache.spark.SparkConf
-import org.apache.spark.util.Utils
+import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.util.{AkkaUtils, Utils}
 
 class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
   private val testConf = new SparkConf(false)
@@ -121,6 +124,88 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
     newFile.delete()
   }
 
+  private def checkSegments(segment1: FileSegment, segment2: FileSegment) {
+    assert(segment1.file.getCanonicalPath === segment2.file.getCanonicalPath)
+    assert(segment1.offset === segment2.offset)
+    assert(segment1.length === segment2.length)
+  }
+
+  test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") {
+
+    val serializer = new JavaSerializer(testConf)
+    val confCopy = testConf.clone
+    // Reset after EACH object write. This ensures that there are bytes appended after
+    // an object is written, so if any code path assumes writeObject marks the end of the data,
+    // this should flush those bugs out. This was a common bug in ExternalAppendOnlyMap, etc.
+    confCopy.set("spark.serializer.objectStreamReset", "1")
+
+    val securityManager = new org.apache.spark.SecurityManager(confCopy)
+    // Do not use the shuffleBlockManager above!
+    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, confCopy,
+      securityManager)
+    val master = new BlockManagerMaster(
+      actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))),
+      confCopy)
+    val store = new BlockManager("<driver>", actorSystem, master, serializer, confCopy,
+      securityManager, null)
+
+    try {
+
+      val shuffleManager = store.shuffleBlockManager
+
+      val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer)
+      for (writer <- shuffle1.writers) {
+        writer.write("test1")
+        writer.write("test2")
+      }
+      for (writer <- shuffle1.writers) {
+        writer.commitAndClose()
+      }
+
+      val shuffle1Segment = shuffle1.writers(0).fileSegment()
+      shuffle1.releaseWriters(success = true)
+
+      val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf))
+
+      for (writer <- shuffle2.writers) {
+        writer.write("test3")
+        writer.write("test4")
+      }
+      for (writer <- shuffle2.writers) {
+        writer.commitAndClose()
+      }
+      val shuffle2Segment = shuffle2.writers(0).fileSegment()
+      shuffle2.releaseWriters(success = true)
+
+      // Now comes the test:
+      // Write to shuffle 3 and close it, but before registering it, check that the file lengths
+      // for the previous task (of shuffle1) are the same as the recorded segments. Earlier, we
+      // were inferring the length of a block from the remaining data in the file, which could
+      // mess things up when there are concurrent reads and writes to the same shuffle group.
+
+      val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf))
+      for (writer <- shuffle3.writers) {
+        writer.write("test3")
+        writer.write("test4")
+      }
+      for (writer <- shuffle3.writers) {
+        writer.commitAndClose()
+      }
+      // check before we register.
+      checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0)))
+      shuffle3.releaseWriters(success = true)
+      checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0)))
+      shuffleManager.removeShuffle(1)
+    } finally {
+
+      if (store != null) {
+        store.stop()
+      }
+      actorSystem.shutdown()
+      actorSystem.awaitTermination()
+    }
+  }
+
   def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) {
     val segment = diskBlockManager.getBlockLocation(blockId)
     assert(segment.file.getName === filename)

tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala

Lines changed: 2 additions & 3 deletions

@@ -61,10 +61,9 @@ object StoragePerfTester {
       for (i <- 1 to recordsPerMap) {
         writers(i % numOutputSplits).write(writeData)
       }
-      writers.map {w =>
-        w.commit()
+      writers.map { w =>
+        w.commitAndClose()
         total.addAndGet(w.fileSegment().length)
-        w.close()
       }
 
       shuffle.releaseWriters(true)
