SPARK-3874, Provide stable TaskContext API
ScrapCodes committed Oct 13, 2014
1 parent 14f222f commit ef633f5
Showing 12 changed files with 32 additions and 96 deletions.
74 changes: 5 additions & 69 deletions core/src/main/java/org/apache/spark/TaskContext.java
@@ -37,77 +37,31 @@
* Contextual information about a task which can be read or mutated during execution.
*/
@DeveloperApi
public class TaskContext implements Serializable {
public abstract class TaskContext implements Serializable {

private int stageId;
private int partitionId;
private long attemptId;
private boolean runningLocally;
private TaskMetrics taskMetrics;

/**
* :: DeveloperApi ::
* Contextual information about a task which can be read or mutated during execution.
*
* @param stageId stage id
* @param partitionId index of the partition
* @param attemptId the number of attempts to execute this task
* @param runningLocally whether the task is running locally in the driver JVM
* @param taskMetrics performance metrics of the task
*/
@DeveloperApi
public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally,
TaskMetrics taskMetrics) {
TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally,
TaskMetrics taskMetrics) {
this.attemptId = attemptId;
this.partitionId = partitionId;
this.runningLocally = runningLocally;
this.stageId = stageId;
this.taskMetrics = taskMetrics;
}

/**
* :: DeveloperApi ::
* Contextual information about a task which can be read or mutated during execution.
*
* @param stageId stage id
* @param partitionId index of the partition
* @param attemptId the number of attempts to execute this task
* @param runningLocally whether the task is running locally in the driver JVM
*/
@DeveloperApi
public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) {
this.attemptId = attemptId;
this.partitionId = partitionId;
this.runningLocally = runningLocally;
this.stageId = stageId;
this.taskMetrics = TaskMetrics.empty();
}

/**
* :: DeveloperApi ::
* Contextual information about a task which can be read or mutated during execution.
*
* @param stageId stage id
* @param partitionId index of the partition
* @param attemptId the number of attempts to execute this task
*/
@DeveloperApi
public TaskContext(int stageId, int partitionId, long attemptId) {
this.attemptId = attemptId;
this.partitionId = partitionId;
this.runningLocally = false;
this.stageId = stageId;
this.taskMetrics = TaskMetrics.empty();
}

private static ThreadLocal<TaskContext> taskContext =
new ThreadLocal<TaskContext>();

/**
* :: Internal API ::
* This is spark internal API, not intended to be called from user programs.
*/
public static void setTaskContext(TaskContext tc) {
static void setTaskContext(TaskContext tc) {
taskContext.set(tc);
}

@@ -116,7 +70,7 @@ public static TaskContext get() {
}

/** :: Internal API :: */
public static void unset() {
static void unset() {
taskContext.remove();
}

@@ -222,20 +176,14 @@ public void markInterrupted() {
interrupted = true;
}

@Deprecated
/** Deprecated: use getStageId() */
public int stageId() {
return stageId;
}

@Deprecated
/** Deprecated: use getPartitionId() */
public int partitionId() {
return partitionId;
}

@Deprecated
/** Deprecated: use getAttemptId() */
public long attemptId() {
return attemptId;
}
@@ -250,18 +198,6 @@ public boolean isRunningLocally() {
return runningLocally;
}

public int getStageId() {
return stageId;
}

public int getPartitionId() {
return partitionId;
}

public long getAttemptId() {
return attemptId;
}

/** ::Internal API:: */
public TaskMetrics taskMetrics() {
return taskMetrics;
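Taken together, the changes in this file make TaskContext abstract with a package-private constructor, so user programs can no longer instantiate it and instead read the current task's context through the static accessor. A minimal user-side sketch under the new API (the application name and data are illustrative, not part of this commit):

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object TaskContextExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("task-context-example"))
    val perPartition = sc.parallelize(1 to 10, 2).mapPartitions { iter =>
      // The context is obtained, never constructed, in user code.
      val ctx = TaskContext.get()
      Iterator((ctx.stageId(), ctx.partitionId(), ctx.attemptId(), iter.length))
    }.collect()
    perPartition.foreach(println)
    sc.stop()
  }
}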
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -196,7 +196,7 @@ class HadoopRDD[K, V](
val jobConf = getJobConf()
val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf)
context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)

// Register an on-task-completion callback to close the input stream.
@@ -956,9 +956,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId,
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outfmt.newInstance
@@ -1027,9 +1027,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
val attemptNumber = (context.attemptId % Int.MaxValue).toInt

writer.setup(context.getStageId, context.getPartitionId, attemptNumber)
writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
try {
var count = 0
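A side note on the roll-over used in both hunks above: Hadoop's task attempt number is a 32-bit int while Spark's attempt id is a 64-bit long, so the id is folded with a modulo before being handed to Hadoop. A tiny worked example (the input value is made up for illustration):

val sparkAttemptId: Long = 4294967299L                      // illustrative value larger than Int.MaxValue
val attemptNumber = (sparkAttemptId % Int.MaxValue).toInt
// Int.MaxValue is 2147483647, so 4294967299 % 2147483647 == 5 and attemptNumber == 5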
@@ -633,14 +633,14 @@ class DAGScheduler(
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
val taskContext =
new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
TaskContext.setTaskContext(taskContext)
new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true)
TaskContextHelper.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
taskContext.markTaskCompleted()
TaskContext.unset()
TaskContextHelper.unset()
}
} catch {
case e: Exception =>
8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer

import scala.collection.mutable.HashMap

import org.apache.spark.TaskContext
import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.ByteBufferInputStream
@@ -45,8 +45,8 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {

final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, false)
TaskContext.setTaskContext(context)
context = new TaskContextImpl(stageId, partitionId, attemptId, false)
TaskContextHelper.setTaskContext(context)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
if (_killed) {
@@ -56,7 +56,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
runTask(context)
} finally {
context.markTaskCompleted()
TaskContext.unset()
TaskContextHelper.unset()
}
}

2 changes: 1 addition & 1 deletion core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -776,7 +776,7 @@ public void persist() {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics());
TaskContext context = new TaskContextImpl(0, 0, 0L, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}

@@ -30,8 +30,8 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
public void onTaskCompletion(TaskContext context) {
context.isCompleted();
context.isInterrupted();
context.getStageId();
context.getPartitionId();
context.stageId();
context.partitionId();
context.isRunningLocally();
context.addTaskCompletionListener(this);
}
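The listener above exercises the public surface of the stable API from Java test code. From user Scala code the same completion hook can be registered through the function overload, for example to release a per-task resource. A minimal sketch (the temp-file logging is purely illustrative; assumes a spark-shell style context where sc is available):

import org.apache.spark.TaskContext

val lengths = sc.parallelize(Seq("a", "bb", "ccc"), 2).mapPartitions { iter =>
  val log = new java.io.PrintWriter(java.io.File.createTempFile("task-", ".log"))
  // Close the per-task resource when the task finishes, whether it succeeded or failed.
  TaskContext.get().addTaskCompletionListener(_ => log.close())
  iter.map { s => log.println(s); s.length }
}.collect()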
8 changes: 4 additions & 4 deletions core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
val context = new TaskContext(0, 0, 0)
val context = new TaskContextImpl(0, 0, 0)
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -81,7 +81,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}

whenExecuting(blockManager) {
val context = new TaskContext(0, 0, 0)
val context = new TaskContextImpl(0, 0, 0)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -94,15 +94,15 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}

whenExecuting(blockManager) {
val context = new TaskContext(0, 0, 0, true)
val context = new TaskContextImpl(0, 0, 0, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
}

test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
val context = new TaskContext(0, 0, 0)
val context = new TaskContextImpl(0, 0, 0)
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
@@ -174,7 +174,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
val tContext = new TaskContext(0, 0, 0)
val tContext = new TaskContextImpl(0, 0, 0)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
@@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
}

test("all TaskCompletionListeners should be called even if some fail") {
val context = new TaskContext(0, 0, 0)
val context = new TaskContextImpl(0, 0, 0)
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
@@ -17,7 +17,7 @@

package org.apache.spark.storage

import org.apache.spark.TaskContext
import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}

import org.mockito.Mockito._
@@ -62,7 +62,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
)

val iterator = new ShuffleBlockFetcherIterator(
new TaskContext(0, 0, 0),
new TaskContextImpl(0, 0, 0),
transfer,
blockManager,
blocksByAddress,
@@ -120,7 +120,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
)

val iterator = new ShuffleBlockFetcherIterator(
new TaskContext(0, 0, 0),
new TaskContextImpl(0, 0, 0),
transfer,
blockManager,
blocksByAddress,
@@ -169,7 +169,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
(bmId, Seq((blId1, 1L), (blId2, 1L))))

val iterator = new ShuffleBlockFetcherIterator(
new TaskContext(0, 0, 0),
new TaskContextImpl(0, 0, 0),
transfer,
blockManager,
blocksByAddress,
@@ -289,9 +289,9 @@ case class InsertIntoParquetTable(
def writeShard(context: TaskContext, iter: Iterator[Row]): Int = {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId,
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = new AppendingParquetOutputFormat(taskIdOffset)
