Created a monotask to serialize macrotask results.
Created a new ComputeMonotask called ResultSerializationMonotask that
serializes the macrotask result generated by an ExecutionMonotask. There
is exactly one ResultSerializationMonotask per macrotask, and it depends
on all of the "leaves" of the macrotask's monotask DAG, so it is always
the DAG's only sink and therefore the last monotask to run.

This change is needed because TaskMetrics.setMetricsOnTaskCompletion()
must not be called until all of a macrotask's compute, disk, and network
monotasks have finished executing; otherwise, the TaskMetrics will not
accurately reflect the resource usage over the entire time that the
macrotask was running.

See issue apache#22.
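
For illustration, below is a minimal, self-contained Scala sketch of the
"only sink" wiring described above. ToyMonotask, OnlySinkSketch, and the
monotask names used here (DiskReadMonotask, NetworkMonotask) are simplified,
hypothetical stand-ins for the real classes; the actual wiring is in
Macrotask.getMonotasks in the diff below.

import scala.collection.mutable

// Simplified stand-in for the real Monotask class: just the dependency
// bookkeeping needed to show the shape of the DAG.
class ToyMonotask(val name: String) {
  val dependencies = mutable.HashSet[ToyMonotask]()
  val dependents = mutable.HashSet[ToyMonotask]()

  def addDependency(dependency: ToyMonotask): Unit = {
    dependencies += dependency
    dependency.dependents += this
  }
}

object OnlySinkSketch extends App {
  // Hypothetical input monotasks, standing in for what rdd.buildDag() produces.
  val diskRead = new ToyMonotask("DiskReadMonotask")
  val networkFetch = new ToyMonotask("NetworkMonotask")
  val execution = new ToyMonotask("ExecutionMonotask")
  execution.addDependency(diskRead)

  // Wire in the serialization monotask the way Macrotask.getMonotasks does:
  // it depends on the ExecutionMonotask plus every remaining leaf of the DAG.
  val resultSerialization = new ToyMonotask("ResultSerializationMonotask")
  resultSerialization.addDependency(execution)
  val inputMonotasks = Seq(diskRead, networkFetch)
  inputMonotasks.filter(_.dependents.isEmpty)
    .foreach(resultSerialization.addDependency(_))

  // The serialization monotask is now the DAG's only sink, so it runs last
  // and the metrics it serializes cover the whole macrotask.
  val allMonotasks = inputMonotasks ++ Seq(execution, resultSerialization)
  val sinks = allMonotasks.filter(_.dependents.isEmpty)
  assert(sinks == Seq(resultSerialization))
  println(sinks.map(_.name))  // List(ResultSerializationMonotask)
}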
Christopher Canel committed Jun 8, 2015
1 parent cb7d5ea commit 44357d7
Showing 14 changed files with 237 additions and 205 deletions.
@@ -18,7 +18,7 @@ package org.apache.spark.monotasks

import java.nio.ByteBuffer

import scala.collection.mutable.{HashMap, HashSet}
import scala.collection.mutable.HashSet

import org.apache.spark.{Logging, TaskState}
import org.apache.spark.executor.ExecutorBackend
@@ -53,24 +53,17 @@ private[spark] class LocalDagScheduler(
* debugging/testing and is not needed for maintaining correctness. */
val runningMonotasks = new HashSet[Long]()

/* Maps macrotask attempt ID to the IDs of monotasks that are part of that macrotask but have not
* finished executing yet. A macrotask is not finished until its set is empty and it has a result.
* This is also used to determine whether to notify the executor backend that a task has failed
* (used to avoid duplicate failure messages if multiple monotasks for the macrotask fail). */
val macrotaskRemainingMonotasks = new HashMap[Long, HashSet[Long]]()

/* Maps macrotask attempt ID to that macrotask's serialized task result. This gives the
* LocalDagScheduler a way to store macrotask result buffers in the event that the monotask that
* creates the result is not the last monotask to execute for that macrotask (the macrotask cannot
* return its result until all of its monotasks have finished). */
val macrotaskResults = new HashMap[Long, ByteBuffer]()
/* IDs for macrotasks that currently are running. Used to determine whether to notify the
* executor backend that a task has failed (used to avoid duplicate failure messages if multiple
* monotasks for the macrotask fail). */
val runningMacrotaskAttemptIds = new HashSet[Long]()

def getNumRunningComputeMonotasks(): Int = {
computeScheduler.numRunningTasks.get()
}

def getNumRunningMacrotasks(): Int = {
macrotaskRemainingMonotasks.keySet.size
runningMacrotaskAttemptIds.size
}

def getOutstandingNetworkBytes(): Long = networkScheduler.getOutstandingBytes
@@ -83,8 +76,7 @@ private[spark] class LocalDagScheduler(
}
val taskAttemptId = monotask.context.taskAttemptId
logDebug(s"Submitting monotask $monotask (id: ${monotask.taskId}) for macrotask $taskAttemptId")
macrotaskRemainingMonotasks.getOrElseUpdate(taskAttemptId, new HashSet[Long]()) +=
monotask.taskId
runningMacrotaskAttemptIds += taskAttemptId
}

/** It is assumed that all monotasks for a specific macrotask are submitted at the same time. */
@@ -94,7 +86,7 @@

/**
* Marks the monotask as successfully completed by updating the dependency tree and running any
* newly runnable monotasks.
* newly-runnable monotasks.
*
* @param completedMonotask The monotask that has completed.
* @param serializedTaskResult If the monotask was the final monotask for the macrotask, a
@@ -106,9 +98,9 @@
val taskAttemptId = completedMonotask.context.taskAttemptId
logDebug(s"Monotask $completedMonotask (id: ${completedMonotask.taskId}) for " +
s"macrotask $taskAttemptId has completed.")
runningMonotasks.remove(completedMonotask.taskId)

if (macrotaskRemainingMonotasks.contains(taskAttemptId)) {
if (runningMacrotaskAttemptIds.contains(taskAttemptId)) {
// If the macrotask has not failed, schedule any newly-ready monotasks.
completedMonotask.dependents.foreach { monotask =>
monotask.dependencies -= completedMonotask.taskId
if (monotask.dependencies.isEmpty) {
@@ -119,33 +111,22 @@
}
}

if ((macrotaskRemainingMonotasks(taskAttemptId) -= completedMonotask.taskId).isEmpty) {
// All monotasks for this macrotask have completed, so send the result to the
// executorBackend.
serializedTaskResult.orElse(macrotaskResults.get(taskAttemptId)).map { result =>
completedMonotask.context.markTaskCompleted()
logDebug(s"Notfiying executorBackend about successful completion of task $taskAttemptId")
executorBackend.statusUpdate(taskAttemptId, TaskState.FINISHED, result)

macrotaskRemainingMonotasks -= taskAttemptId
macrotaskResults -= taskAttemptId
}.getOrElse{
logError(s"Macrotask $taskAttemptId does not have a result even though all of its " +
"monotasks have completed.")
}
} else {
// If we received a result, store it so it can be passed to the executorBackend once all of
// the monotasks for this macrotask have completed.
serializedTaskResult.foreach(macrotaskResults(taskAttemptId) = _)
serializedTaskResult.map { result =>
// Tell the executorBackend that the macrotask finished.
runningMacrotaskAttemptIds.remove(taskAttemptId)
completedMonotask.context.markTaskCompleted()
logDebug(s"Notfiying executorBackend about successful completion of task $taskAttemptId")
executorBackend.statusUpdate(taskAttemptId, TaskState.FINISHED, result)
}
} else {
// Another monotask in this macrotask must have failed while completedMonotask was running,
// causing the macrotask to fail and its taskAttemptId to be removed from
// macrotaskRemainingMonotasks. We should fail completedMonotask's dependents in case they
// have not been failed already, which can happen if they are not dependents of the monotask
// that failed.
// This will only happen if another monotask in this macrotask failed while completedMonotask
// was running, causing the macrotask to fail and its taskAttemptId to be removed from
// runningMacrotaskAttemptIds. We should fail completedMonotask's dependents in case they have
// not been failed already, which can happen if they are not dependents of the monotask that
// failed.
failDependentMonotasks(completedMonotask)
}
runningMonotasks.remove(completedMonotask.taskId)
}

/**
@@ -162,13 +143,12 @@
runningMonotasks -= failedMonotask.taskId
failDependentMonotasks(failedMonotask, Some(failedMonotask.taskId))
val taskAttemptId = failedMonotask.context.taskAttemptId

// Notify the executor backend that the macrotask has failed, if we didn't already.
if (macrotaskRemainingMonotasks.remove(taskAttemptId).isDefined) {
if (runningMacrotaskAttemptIds.remove(taskAttemptId)) {
failedMonotask.context.markTaskCompleted()
executorBackend.statusUpdate(taskAttemptId, TaskState.FAILED, serializedFailureReason)
}

macrotaskResults.remove(taskAttemptId)
}

private def failDependentMonotasks(
@@ -177,7 +157,7 @@
// TODO: We don't interrupt monotasks that are already running. See
// https://github.com/NetSys/spark-monotasks/issues/10
val message = originalFailedTaskId.map { taskId =>
s"it dependend on monotask $taskId, which failed"
s"it depended on monotask $taskId, which failed"
}.getOrElse(s"another monotask in macrotask ${monotask.context.taskAttemptId} failed")

monotask.dependents.foreach { dependentMonotask =>
@@ -85,8 +85,8 @@ private[spark] abstract class ComputeMonotask(context: TaskContextImpl)
/**
* Adds the time taken by this monotask to the macrotask's TaskMetrics, if it hasn't already been
* done for this monotask. This method needs to check whether it was already called because
* ExecutionMonotasks need to call this method themselves, before serializing the task result
* (otherwise the update to the metrics won't be reflected in the serialized TaskMetrics
* ResultSerializationMonotasks need to call this method themselves, before serializing the task
* result (otherwise the update to the metrics won't be reflected in the serialized TaskMetrics
* that are sent back to the driver).
*/
protected def accountForComputeTime() {
@@ -20,58 +20,29 @@ import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.{Accumulators, Logging, Partition, TaskContextImpl}
import org.apache.spark.{Partition, TaskContextImpl}
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult}
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.storage.{MonotaskResultBlockId, StorageLevel}

/**
* Monotask that handles executing the core computation of a macro task and serializing the result.
* Monotask that handles executing the core computation of a macrotask. The result is stored in
* memory.
*/
private[spark] abstract class ExecutionMonotask[T, U: ClassTag](
context: TaskContextImpl,
val rdd: RDD[T],
val split: Partition)
extends ComputeMonotask(context) with Logging {
extends ComputeMonotask(context) {

// BlockId used to store this monotask's result in the BlockManager.
val resultBlockId = new MonotaskResultBlockId(taskId)

/** Subclasses should define this to return a macrotask result to be sent to the driver. */
def getResult(): U

override protected def execute(): Option[ByteBuffer] = {
val result = getResult()
context.markTaskCompleted()
val serializedResult = serializeResult(result)
Some(serializedResult)
}

private def serializeResult(result: U): ByteBuffer = {
// The mysterious choice of which serializer to use when is written to be consistent with Spark.
val closureSerializer = context.env.closureSerializer.newInstance()
val resultSer = context.env.serializer.newInstance()

val serializationStartTime = System.currentTimeMillis()
val valueBytes = resultSer.serialize(result)
context.taskMetrics.setResultSerializationTime(
System.currentTimeMillis() - serializationStartTime)
accountForComputeTime()

context.taskMetrics.setMetricsOnTaskCompletion()
val accumulatorValues = Accumulators.getValues
val directResult = new DirectTaskResult(valueBytes, accumulatorValues, context.taskMetrics)
val serializedDirectResult = closureSerializer.serialize(directResult)
val resultSize = serializedDirectResult.limit

if (context.maximumResultSizeBytes > 0 && resultSize > context.maximumResultSizeBytes) {
val blockId = TaskResultBlockId(context.taskAttemptId)
context.localDagScheduler.blockManager.cacheBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
logInfo(s"Finished TID ${context.taskAttemptId}. $resultSize bytes result will be sent " +
"via the BlockManager)")
closureSerializer.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
logInfo(s"Finished TID ${context.taskAttemptId}. $resultSize bytes result will be sent " +
"directly to driver")
serializedDirectResult
}
context.localDagScheduler.blockManager.cacheSingle(
resultBlockId, getResult(), StorageLevel.MEMORY_ONLY, false)
None
}
}
@@ -0,0 +1,76 @@
/*
 * Copyright 2014 The Regents of The University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.monotasks.compute

import java.nio.ByteBuffer

import org.apache.spark.{Accumulators, Logging, TaskContextImpl}
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult}
import org.apache.spark.storage.{BlockId, StorageLevel, TaskResultBlockId}

/**
* ResultSerializationMonotasks are responsible for serializing the result of a macrotask and the
* associated metrics. The DAG for a macrotask always contains exactly one
* ResultSerializationMonotask, and it is run after all of the macrotask's other monotasks have
* completed (because otherwise the metrics computed by ResultSerializationMonotask would not be
* complete).
*/
class ResultSerializationMonotask(context: TaskContextImpl, resultBlockId: BlockId)
extends ComputeMonotask(context) with Logging {

override def execute(): Option[ByteBuffer] = {
val blockManager = context.localDagScheduler.blockManager
blockManager.getSingle(resultBlockId).map { result =>
blockManager.removeBlockFromMemory(resultBlockId, false)
context.markTaskCompleted()

// The mysterious choice of which serializer to use when is written to be consistent with
// Spark.
val closureSerializer = context.env.closureSerializer.newInstance()
val resultSer = context.env.serializer.newInstance()

val serializationStartTime = System.currentTimeMillis()
val valueBytes = resultSer.serialize(result)
context.taskMetrics.setResultSerializationTime(
System.currentTimeMillis() - serializationStartTime)
accountForComputeTime()

context.taskMetrics.setMetricsOnTaskCompletion()
val accumulatorValues = Accumulators.getValues
val directResult = new DirectTaskResult(valueBytes, accumulatorValues, context.taskMetrics)
val serializedDirectResult = closureSerializer.serialize(directResult)
val resultSize = serializedDirectResult.limit

if (context.maximumResultSizeBytes > 0 && resultSize > context.maximumResultSizeBytes) {
val blockId = TaskResultBlockId(context.taskAttemptId)
context.localDagScheduler.blockManager.cacheBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
logInfo(s"Finished TID ${context.taskAttemptId}. $resultSize bytes result will be sent " +
"via the BlockManager.")
closureSerializer.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
logInfo(s"Finished TID ${context.taskAttemptId}. $resultSize bytes result will be sent " +
"directly to driver.")
serializedDirectResult
}
}.orElse {
throw new IllegalStateException(s"Deserialized result for macrotask " +
s"${context.taskAttemptId} could not be found in the BlockManager " +
s"using blockId $resultBlockId.")
}
}
}
28 changes: 25 additions & 3 deletions core/src/main/scala/org/apache/spark/scheduler/Macrotask.scala
@@ -20,9 +20,12 @@ import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.ByteBuffer

import scala.collection.mutable.{HashMap, HashSet}
import scala.language.existentials

import org.apache.spark.{Logging, Partition, TaskContextImpl}
import org.apache.spark.monotasks.Monotask
import org.apache.spark.monotasks.compute.{ExecutionMonotask, ResultSerializationMonotask}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.ByteBufferInputStream

@@ -49,10 +52,29 @@ private[spark] abstract class Macrotask[T](val stageId: Int, val partition: Part
var epoch: Long = -1

/**
* Returns the monotasks that need to be run in order to execute this macrotask. This is run
* within a compute monotask, so should not use network or disk.
* Deserializes the macrotask binary and returns the RDD that this macrotask will operate on, as
* well as the ExecutionMonotask that will perform the computation. This function is run within a
* compute monotask, so should not use network or disk.
*/
def getMonotasks(context: TaskContextImpl): Seq[Monotask]
def getExecutionMonotask(context: TaskContextImpl): (RDD[_], ExecutionMonotask[_, _])

/**
* Returns the monotasks that need to be run in order to execute this macrotask. This function is
* run within a compute monotask, so should not use network or disk.
*/
def getMonotasks(context: TaskContextImpl): Seq[Monotask] = {
val (rdd, executionMonotask) = getExecutionMonotask(context)
val resultSerializationMonotask =
new ResultSerializationMonotask(context, executionMonotask.resultBlockId)
resultSerializationMonotask.addDependency(executionMonotask)

val inputMonotasks =
rdd.buildDag(partition, dependencyIdToPartitions, context, executionMonotask)
val leaves = inputMonotasks.filter(_.dependents.isEmpty)
leaves.foreach(resultSerializationMonotask.addDependency(_))

inputMonotasks ++ Seq(executionMonotask, resultSerializationMonotask)
}
}

/**
@@ -23,8 +23,7 @@ import scala.reflect.ClassTag

import org.apache.spark.{Logging, Partition, TaskContext, TaskContextImpl}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.monotasks.Monotask
import org.apache.spark.monotasks.compute.ResultMonotask
import org.apache.spark.monotasks.compute.{ExecutionMonotask, ResultMonotask}
import org.apache.spark.rdd.RDD

/**
@@ -49,16 +48,11 @@ private[spark] class ResultMacrotask[T, U: ClassTag](

override def toString = s"ResultTask($stageId, ${partition.index})"

// Deserializes the task binary and creates the rest of the monotasks needed to run the
// macrotask.
override def getMonotasks(context: TaskContextImpl): Seq[Monotask] = {
override def getExecutionMonotask(context: TaskContextImpl): (RDD[_], ExecutionMonotask[_, _]) = {
// TODO: Task.run() setups up TaskContext and sets hostname in metrics; need to do that here!
val ser = context.env.closureSerializer.newInstance()
val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
ByteBuffer.wrap(taskBinary.value), context.dependencyManager.replClassLoader)

val computeMonotask = new ResultMonotask(context, rdd, partition, func)
val inputMonotasks = rdd.buildDag(partition, dependencyIdToPartitions, context, computeMonotask)
inputMonotasks ++ Seq(computeMonotask)
(rdd, new ResultMonotask(context, rdd, partition, func))
}
}
@@ -23,8 +23,7 @@ import scala.language.existentials

import org.apache.spark.{Partition, ShuffleDependency, TaskContextImpl}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.monotasks.Monotask
import org.apache.spark.monotasks.compute.ShuffleMapMonotask
import org.apache.spark.monotasks.compute.{ExecutionMonotask, ShuffleMapMonotask}
import org.apache.spark.rdd.RDD

/**
@@ -47,13 +46,10 @@ private[spark] class ShuffleMapMacrotask(

override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition.index)

override def getMonotasks(context: TaskContextImpl): Seq[Monotask] = {
override def getExecutionMonotask(context: TaskContextImpl): (RDD[_], ExecutionMonotask[_, _]) = {
val ser = context.env.closureSerializer.newInstance()
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[Any, Any, _])](
ByteBuffer.wrap(taskBinary.value), context.dependencyManager.replClassLoader)

val computeMonotask = new ShuffleMapMonotask(context, rdd, partition, dep)
val inputMonotasks = rdd.buildDag(partition, dependencyIdToPartitions, context, computeMonotask)
inputMonotasks ++ Seq(computeMonotask)
(rdd, new ShuffleMapMonotask(context, rdd, partition, dep))
}
}
