diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 9b7f901c55e00..bbf21e0b803b1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -64,6 +64,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
           val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
             case directResult: DirectTaskResult[_] =>
               if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
+                scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
+                  "Tasks result size has exceeded maxResultSize"))
                 return
               }
               // deserialize "value" without holding any lock so that it won't block other threads.
@@ -75,6 +77,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
             if (!taskSetManager.canFetchMoreResults(size)) {
               // dropped by executor if size is larger than maxResultSize
               sparkEnv.blockManager.master.removeBlock(blockId)
+              scheduler.handleFailedTask(taskSetManager, tid, TaskState.KILLED, TaskKilled(
+                "Tasks result size has exceeded maxResultSize"))
               return
             }
             logDebug("Fetching indirect task result for TID %s".format(tid))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index ae464352da440..a8f975609134f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark._
+import org.apache.spark.TaskState.TaskState
 import org.apache.spark.TestUtils.JavaSourceFromString
 import org.apache.spark.internal.config.Network.RPC_MESSAGE_MAX_SIZE
 import org.apache.spark.storage.TaskResultBlockId
@@ -78,6 +79,16 @@ private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: Task
   }
 }
 
+private class DummyTaskSchedulerImpl(sc: SparkContext)
+  extends TaskSchedulerImpl(sc, 1, true) {
+  override def handleFailedTask(
+      taskSetManager: TaskSetManager,
+      tid: Long,
+      taskState: TaskState,
+      reason: TaskFailedReason): Unit = {
+    // do nothing
+  }
+}
 
 /**
  * A [[TaskResultGetter]] that stores the [[DirectTaskResult]]s it receives from executors
@@ -130,6 +141,29 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
       "Expect result to be removed from the block manager.")
   }
 
+  test("handling total size of results larger than maxResultSize") {
+    sc = new SparkContext("local", "test", conf)
+    val scheduler = new DummyTaskSchedulerImpl(sc)
+    val spyScheduler = spy(scheduler)
+    val resultGetter = new TaskResultGetter(sc.env, spyScheduler)
+    spyScheduler.taskResultGetter = resultGetter
+    val myTsm = new TaskSetManager(spyScheduler, FakeTask.createTaskSet(2), 1) {
+      // always returns false
+      override def canFetchMoreResults(size: Long): Boolean = false
+    }
+    val indirectTaskResult = IndirectTaskResult(TaskResultBlockId(0), 0)
+    val directTaskResult = new DirectTaskResult(ByteBuffer.allocate(0), Nil, Array())
+    val ser = sc.env.closureSerializer.newInstance()
+    val serializedIndirect = ser.serialize(indirectTaskResult)
+    val serializedDirect = ser.serialize(directTaskResult)
+    resultGetter.enqueueSuccessfulTask(myTsm, 0, serializedDirect)
+    resultGetter.enqueueSuccessfulTask(myTsm, 1, serializedIndirect)
+    verify(spyScheduler, times(1)).handleFailedTask(
+      myTsm, 0, TaskState.KILLED, TaskKilled("Tasks result size has exceeded maxResultSize"))
+    verify(spyScheduler, times(1)).handleFailedTask(
+      myTsm, 1, TaskState.KILLED, TaskKilled("Tasks result size has exceeded maxResultSize"))
+  }
+
   test("task retried if result missing from block manager") {
     // Set the maximum number of task failures to > 0, so that the task set isn't aborted
     // after the result is missing.
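
For context only, here is a minimal, hypothetical driver program (not part of the patch) that exercises the code path this change touches. The `1m` value for `spark.driver.maxResultSize`, the object name, and the job shape are illustrative assumptions: each task returns roughly 1 MB, so the accumulated result size quickly exceeds the cap checked by `canFetchMoreResults`.

```scala
import org.apache.spark.{SparkConf, SparkContext, SparkException}

// Hypothetical sketch: a collect() whose total result size exceeds
// spark.driver.maxResultSize. With this patch the over-limit tasks are
// additionally reported to the scheduler as KILLED with a
// TaskKilled("Tasks result size has exceeded maxResultSize") reason,
// instead of being left marked as running after the job is aborted.
object MaxResultSizeExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("max-result-size-example")
      .set("spark.driver.maxResultSize", "1m") // deliberately tiny cap
    val sc = new SparkContext(conf)
    try {
      // Each of the 8 tasks returns about 1 MB, so the driver-side total
      // exceeds the 1m limit and the job is aborted.
      sc.parallelize(1 to 8, 8)
        .map(_ => new Array[Byte](1024 * 1024))
        .collect()
    } catch {
      case e: SparkException =>
        // collect() fails once the accumulated result size is over the cap.
        println(s"Job failed as expected: ${e.getMessage}")
    } finally {
      sc.stop()
    }
  }
}
```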