/
ExecutorSuite.scala
422 lines (382 loc) · 17.6 KB
/
ExecutorSuite.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.executor
import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.lang.Thread.UncaughtExceptionHandler
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable.Map
import scala.concurrent.duration._
import scala.language.postfixOps
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.{inOrder, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.concurrent.Eventually
import org.scalatest.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.memory.MemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.UninterruptibleThread
/**
 * Tests for [[Executor]]: killing a task before it is deserialized, reporting fetch failures that
 * user code hides, letting OOMs and task kills supersede a fetch failure, running tasks on
 * [[UninterruptibleThread]]s, and failing gracefully when a task cannot be deserialized.
 */
class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {

  test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
    // mock some objects to make Executor.launchTask() happy
    val conf = new SparkConf
    val serializer = new JavaSerializer(conf)
    val env = createMockEnv(conf, serializer)
    val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0))
    val taskDescription = createFakeTaskDescription(serializedTask)

    // we use latches to force the program to run in this order:
    // +-----------------------------+---------------------------------------+
    // |      main test thread       |             worker thread             |
    // +-----------------------------+---------------------------------------+
    // |    executor.launchTask()    |                                       |
    // |                             |        TaskRunner.run() begins        |
    // |                             |                  ...                  |
    // |                             |  execBackend.statusUpdate // 1st time |
    // | executor.killAllTasks(true) |                                       |
    // |                             |                  ...                  |
    // |                             |         task = ser.deserialize        |
    // |                             |                  ...                  |
    // |                             |  execBackend.statusUpdate // 2nd time |
    // |                             |                  ...                  |
    // |                             |         TaskRunner.run() ends         |
    // |        check results        |                                       |
    // +-----------------------------+---------------------------------------+

    val executorSuiteHelper = new ExecutorSuiteHelper

    val mockExecutorBackend = mock[ExecutorBackend]
    when(mockExecutorBackend.statusUpdate(any(), any(), any()))
      .thenAnswer(new Answer[Unit] {
        var firstTime = true
        override def answer(invocationOnMock: InvocationOnMock): Unit = {
          if (firstTime) {
            executorSuiteHelper.latch1.countDown()
            // here between latch1 and latch2, executor.killAllTasks() is called
            executorSuiteHelper.latch2.await()
            firstTime = false
          } else {
            // save the returned `taskState` and `testFailedReason` into `executorSuiteHelper`
            val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState]
            executorSuiteHelper.taskState = taskState
            val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer]
            executorSuiteHelper.testFailedReason =
              serializer.newInstance().deserialize(taskEndReason)
            // let the main test thread check `taskState` and `testFailedReason`
            executorSuiteHelper.latch3.countDown()
          }
        }
      })

    var executor: Executor = null
    try {
      executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true)
      // the task will be launched in a dedicated worker thread
      executor.launchTask(mockExecutorBackend, taskDescription)

      if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) {
        fail("executor did not send first status update in time")
      }
      // we know the task will be started, but not yet deserialized, because of the latches we
      // use in mockExecutorBackend.
      executor.killAllTasks(true, "test")
      executorSuiteHelper.latch2.countDown()
      if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) {
        fail("executor did not send second status update in time")
      }

      // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED`
      assert(executorSuiteHelper.testFailedReason === TaskKilled("test"))
      assert(executorSuiteHelper.taskState === TaskState.KILLED)
    } finally {
      if (executor != null) {
        executor.stop()
      }
    }
  }

  test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") {
    val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
    sc = new SparkContext(conf)

    // Submit a job where a fetch failure is thrown, but user code has a try/catch which hides
    // the fetch failure. The executor should still tell the driver that the task failed due to a
    // fetch failure, not a generic exception from user code.
    val inputRDD = new FetchFailureThrowingRDD(sc)
    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false)
    val taskDescription = createResultTaskDescription(secondRDD)

    val failReason = runTaskAndGetFailReason(taskDescription)
    assert(failReason.isInstanceOf[FetchFailed])
  }

  test("Executor's worker threads should be UninterruptibleThread") {
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("executor thread test")
      .set("spark.ui.enabled", "false")
    sc = new SparkContext(conf)
    val executorThread = sc.parallelize(Seq(1), 1).map { _ =>
      Thread.currentThread.getClass.getName
    }.collect().head
    assert(executorThread === classOf[UninterruptibleThread].getName)
  }

  test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
    val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true)
    assert(failReason.isInstanceOf[ExceptionFailure])
    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
    assert(exceptionCaptor.getAllValues.size === 1)
    assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError])
  }

  test("SPARK-23816: interrupts are not masked by a FetchFailure") {
    // If killing the task causes a fetch failure, we still treat it as a task that was killed,
    // as the fetch failure could easily be caused by interrupting the thread.
    val (failReason, _) = testFetchFailureHandling(false)
    assert(failReason.isInstanceOf[TaskKilled])
  }

  /**
   * Helper for testing some cases where a FetchFailure should *not* get sent back, because it's
   * superseded by another error, either an OOM or intentionally killing a task.
   * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the
   *            FetchFailure
   * @return the failure reason reported for the task, and the mocked uncaught exception handler
   *         the executor was created with
   */
  private def testFetchFailureHandling(
      oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
    // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
    // may be a false positive.  And we should call the uncaught exception handler.
    // SPARK-23816 also handle interrupts the same way, as killing an obsolete speculative task
    // does not represent a real fetch failure.
    val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
    sc = new SparkContext(conf)

    // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt.  We
    // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling.
    val inputRDD = new FetchFailureThrowingRDD(sc)
    if (!oom) {
      // we are trying to setup a case where a task is killed after a fetch failure -- this
      // is just a helper to coordinate between the task thread and this thread that will
      // kill the task
      ExecutorSuiteHelper.latches = new ExecutorSuiteHelper()
    }
    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom)
    val taskDescription = createResultTaskDescription(secondRDD)

    runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom)
  }

  /**
   * Wraps `rdd` in a [[ResultTask]] over its first partition (stage 1, attempt 0, output 0, result
   * function `itr.size`) and returns a [[TaskDescription]] carrying the serialized task.
   *
   * Requires an active `sc`: the serialized (RDD, function) pair is broadcast through it.
   */
  private def createResultTaskDescription(rdd: RDD[Int]): TaskDescription = {
    val serializer = SparkEnv.get.closureSerializer.newInstance()
    val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
    val taskBinary = sc.broadcast(serializer.serialize((rdd, resultFunc)).array())
    val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
    val task = new ResultTask(
      stageId = 1,
      stageAttemptId = 0,
      taskBinary = taskBinary,
      partition = rdd.partitions(0),
      locs = Seq(),
      outputId = 0,
      localProperties = new Properties(),
      serializedTaskMetrics = serializedTaskMetrics
    )
    createFakeTaskDescription(serializer.serialize(task))
  }

  test("Gracefully handle error in task deserialization") {
    val conf = new SparkConf
    val serializer = new JavaSerializer(conf)
    // Called only for its side effect: it installs a mock SparkEnv via SparkEnv.set, which the
    // executor created by runTaskAndGetFailReason picks up through SparkEnv.get.
    createMockEnv(conf, serializer)
    val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask)
    val taskDescription = createFakeTaskDescription(serializedTask)

    val failReason = runTaskAndGetFailReason(taskDescription)
    failReason match {
      case ef: ExceptionFailure =>
        assert(ef.exception.isDefined)
        assert(ef.exception.get.getMessage() === NonDeserializableTask.errorMsg)
      case _ =>
        fail(s"unexpected failure type: $failReason")
    }
  }

  /**
   * Builds a mock [[SparkEnv]] that returns `conf`, and `serializer` as both the serializer and
   * the closure serializer, plus mocked RPC env, metrics system, memory manager and serializer
   * manager. The mock is also installed globally with `SparkEnv.set`.
   */
  private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
    val mockEnv = mock[SparkEnv]
    val mockRpcEnv = mock[RpcEnv]
    val mockMetricsSystem = mock[MetricsSystem]
    val mockMemoryManager = mock[MemoryManager]
    when(mockEnv.conf).thenReturn(conf)
    when(mockEnv.serializer).thenReturn(serializer)
    when(mockEnv.serializerManager).thenReturn(mock[SerializerManager])
    when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
    when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
    when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
    when(mockEnv.closureSerializer).thenReturn(serializer)
    SparkEnv.set(mockEnv)
    mockEnv
  }

  /**
   * Wraps an already-serialized task in a [[TaskDescription]] with task id 0, attempt 0, empty
   * executor id and name, index 0, no added files or jars, and empty properties.
   */
  private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = {
    new TaskDescription(
      taskId = 0,
      attemptNumber = 0,
      executorId = "",
      name = "",
      index = 0,
      addedFiles = Map[String, Long](),
      addedJars = Map[String, Long](),
      properties = new Properties,
      serializedTask)
  }

  /** Runs `taskDescription` without killing it and returns the reported failure reason. */
  private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
    runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = false)._1
  }

  /**
   * Runs `taskDescription` on a fresh local [[Executor]] built from `SparkEnv.get`, waits for it
   * to finish, and decodes the failure reason from the second status update.
   *
   * @param killTask if true, a separate thread waits for `ExecutorSuiteHelper.latches.latch1` and
   *                 then kills all tasks; the expected final state becomes KILLED instead of FAILED
   * @return the decoded failure reason and the mocked uncaught exception handler
   */
  private def runTaskGetFailReasonAndExceptionHandler(
      taskDescription: TaskDescription,
      killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
    val mockBackend = mock[ExecutorBackend]
    val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
    var executor: Executor = null
    val timedOut = new AtomicBoolean(false)
    try {
      executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
        uncaughtExceptionHandler = mockUncaughtExceptionHandler)
      // the task will be launched in a dedicated worker thread
      executor.launchTask(mockBackend, taskDescription)
      val killingThread: Option[Thread] = if (killTask) {
        val thread = new Thread("kill-task") {
          override def run(): Unit = {
            // wait to kill the task until it has thrown a fetch failure
            if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) {
              // now we can kill the task
              executor.killAllTasks(true, "Killed task, eg. because of speculative execution")
            } else {
              timedOut.set(true)
            }
          }
        }
        thread.start()
        Some(thread)
      } else {
        None
      }
      eventually(timeout(5.seconds), interval(10.milliseconds)) {
        assert(executor.numRunningTasks === 0)
      }
      // The task can finish without ever signalling latch1 (e.g. it failed for another reason),
      // in which case the killing thread only sets `timedOut` after its 10 second wait. Join it
      // first, so the assertion below sees the final value instead of racing with that thread.
      killingThread.foreach(_.join(TimeUnit.SECONDS.toMillis(20)))
      assert(!timedOut.get(), "timed out waiting to be ready to kill tasks")
    } finally {
      if (executor != null) {
        executor.stop()
      }
    }
    val orderedMock = inOrder(mockBackend)
    val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
    orderedMock.verify(mockBackend)
      .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
    val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED
    orderedMock.verify(mockBackend)
      .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture())
    // first statusUpdate for RUNNING has empty data
    assert(statusCaptor.getAllValues().get(0).remaining() === 0)
    // second update is more interesting
    val failureData = statusCaptor.getAllValues.get(1)
    val failReason =
      SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
    (failReason, mockUncaughtExceptionHandler)
  }
}
/**
 * A single-partition RDD with no parents. Its iterator always reports another element, but every
 * call to `next()` throws a [[FetchFailedException]] (block manager "1" at hostA:1234, shuffle 0,
 * map 0, reduce 0).
 */
class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {

  override def compute(split: Partition, context: TaskContext): Iterator[Int] =
    Iterator.continually[Int] {
      throw new FetchFailedException(
        bmAddress = BlockManagerId("1", "hostA", 1234),
        shuffleId = 0,
        mapId = 0,
        reduceId = 0,
        message = "fake fetch failure")
    }

  override protected def getPartitions: Array[Partition] = Array[Partition](new SimplePartition)
}
/** A [[Partition]] whose index is always 0; each test RDD in this file has exactly one. */
class SimplePartition extends Partition { override def index: Int = 0 }
/**
 * A single-partition RDD that drains its parent [[FetchFailureThrowingRDD]]. When the parent
 * throws, the original error is replaced according to the constructor flags:
 *  - `throwOOM`: an [[OutOfMemoryError]] is thrown;
 *  - otherwise, `interrupt`: it asserts the task context recorded a fetch failure, counts down
 *    `ExecutorSuiteHelper.latches.latch1` so the test may kill the task, then waits up to 10
 *    seconds on `latch2` and throws an [[IllegalStateException]] if it was not interrupted;
 *  - neither: a [[RuntimeException]] wrapping (and thereby hiding) the original error is thrown.
 *
 * `sc` is not referenced; the RDD's single dependency comes from `input`.
 */
class FetchFailureHidingRDD(
    sc: SparkContext,
    val input: FetchFailureThrowingRDD,
    throwOOM: Boolean,
    interrupt: Boolean) extends RDD[Int](input) {

  override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
    val upstream = input.compute(split, context)
    try {
      Iterator(upstream.size)
    } catch {
      case hidden: Throwable => replaceFailure(hidden)
    }
  }

  /** Throws the error that supersedes or hides `hidden`, as selected by the constructor flags. */
  private def replaceFailure(hidden: Throwable): Nothing = {
    if (throwOOM) {
      throw new OutOfMemoryError("OOM while handling another exception")
    }
    if (interrupt) {
      // make sure our test is setup correctly
      assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined)
      // tell the test it may now kill this task
      ExecutorSuiteHelper.latches.latch1.countDown()
      // nothing ever counts latch2 down: this wait only ends early when the task is interrupted
      ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS)
      throw new IllegalStateException("timed out waiting to be interrupted")
    }
    throw new RuntimeException("User Exception that hides the original exception", hidden)
  }

  override protected def getPartitions: Array[Partition] = Array[Partition](new SimplePartition)
}
/**
 * Mutable state shared between the test thread and executor threads.
 *
 * The SPARK-15963 test uses `latch1`-`latch3` to order the main test thread against the
 * executor's worker thread and records the task's final `taskState` / `testFailedReason` here.
 * The kill-after-fetch-failure path reaches an instance through [[ExecutorSuiteHelper.latches]]
 * and uses only `latch1` and `latch2`.
 */
private class ExecutorSuiteHelper {
  // Three independent latches: `val a, b, c = expr` evaluates `expr` separately for each name.
  val latch1, latch2, latch3 = new CountDownLatch(1)

  @volatile var taskState: TaskState = _
  @volatile var testFailedReason: TaskFailedReason = _
}

/** Global hand-off point for coordinating the killing of a task from another thread. */
private object ExecutorSuiteHelper {
  // Set by the suite before the task is launched; read by FetchFailureHidingRDD on the task's
  // thread and by the suite's kill-task thread.
  var latches: ExecutorSuiteHelper = _
}
/**
 * A [[FakeTask]] that writes no data of its own when serialized and always fails when
 * deserialized, throwing a [[RuntimeException]] whose message is [[NonDeserializableTask.errorMsg]].
 */
private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
  override def writeExternal(out: ObjectOutput): Unit = ()

  override def readExternal(in: ObjectInput): Unit =
    throw new RuntimeException(NonDeserializableTask.errorMsg)
}

private object NonDeserializableTask {
  /** Message of the exception thrown when a [[NonDeserializableTask]] is read back. */
  val errorMsg: String = "failure in deserialization"
}