[SPARK-1745] Move interrupted flag from TaskContext constructor (minor)
It makes little sense to start a TaskContext that is interrupted. Indeed, I searched for all use cases of it and didn't find a single instance in which `interrupted` is true on construction.

This was inspired by reviewing #640, which adds an additional `@volatile var completed` that is similar. These are not the most urgent changes, but I wanted to push them out before I forget.
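
A minimal before/after sketch of what this means for callers (illustrative only, not part of the diff; the zero arguments simply mirror the tests changed below):

    import org.apache.spark.TaskContext

    // Old signature (for comparison): callers that passed taskMetrics positionally
    // also had to thread an explicit false through for `interrupted`, e.g.
    //   new TaskContext(0, 0, 0, false, false, new TaskMetrics())
    // New signature: the flag is no longer a constructor parameter; it starts out
    // false and is only flipped after construction, e.g. when the task is killed.
    val context = new TaskContext(0, 0, 0)
    context.interrupted = true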

Author: Andrew Or <andrewor14@gmail.com>

Closes #675 from andrewor14/task-context and squashes the following commits:

9575e02 [Andrew Or] Add space
69455d1 [Andrew Or] Merge branch 'master' of github.com:apache/spark into task-context
c471490 [Andrew Or] Oops, removed one flag too many. Adding it back.
85311f8 [Andrew Or] Move interrupted flag from TaskContext constructor
andrewor14 authored and aarondav committed May 8, 2014
1 parent 44dd57f commit c3f8b78
Showing 5 changed files with 17 additions and 22 deletions.
20 changes: 11 additions & 9 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -28,21 +28,23 @@ import org.apache.spark.executor.TaskMetrics
*/
@DeveloperApi
class TaskContext(
- val stageId: Int,
- val partitionId: Int,
- val attemptId: Long,
- val runningLocally: Boolean = false,
- @volatile var interrupted: Boolean = false,
- private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty
- ) extends Serializable {
+ val stageId: Int,
+ val partitionId: Int,
+ val attemptId: Long,
+ val runningLocally: Boolean = false,
+ private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)
+ extends Serializable {

@deprecated("use partitionId", "0.8.1")
def splitId = partitionId

// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]

- // Set to true when the task is completed, before the onCompleteCallbacks are executed.
+ // Whether the corresponding task has been killed.
+ @volatile var interrupted: Boolean = false
+
+ // Whether the task has completed, before the onCompleteCallbacks are executed.
@volatile var completed: Boolean = false

/**
@@ -58,6 +60,6 @@ class TaskContext(
def executeOnCompleteCallbacks() {
completed = true
// Process complete callbacks in the reverse order of registration
- onCompleteCallbacks.reverse.foreach{_()}
+ onCompleteCallbacks.reverse.foreach { _() }
}
}
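
As a hedged aside (my illustration, not code from this change): with `interrupted` now a plain member, the kill path can flip it on the running task's context, and task-side code can observe it cooperatively, along these lines:

    import org.apache.spark.TaskContext

    // Illustration only: something on the executor's kill path sets the flag...
    def markKilled(context: TaskContext): Unit = {
      context.interrupted = true
    }

    // ...and task-side iteration can check it to stop early instead of running on.
    def runUnlessKilled[T](context: TaskContext, iter: Iterator[T])(f: T => Unit): Unit = {
      while (!context.interrupted && iter.hasNext) {
        f(iter.next())
      }
    }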
3 changes: 1 addition & 2 deletions core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -23,7 +23,6 @@ import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.HashMap
- import scala.util.Try

import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
@@ -70,7 +69,7 @@ private[spark] object ShuffleMapTask {
}

// Since both the JarSet and FileSet have the same format this is used for both.
- def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = {
+ def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = {
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val objIn = new ObjectInputStream(in)
val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
2 changes: 1 addition & 1 deletion core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -597,7 +597,7 @@ public void persist() {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContext(0, 0, 0, false, false, new TaskMetrics());
+ TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
}

10 changes: 3 additions & 7 deletions core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -23,7 +23,6 @@ import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.mock.EasyMockSugar

import org.apache.spark.rdd.RDD
- import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage._

// TODO: Test the CacheManager's thread-safety aspects
@@ -59,8 +58,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}

whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
-   taskMetrics = TaskMetrics.empty)
+ val context = new TaskContext(0, 0, 0)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
@@ -72,8 +70,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}

whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
-   taskMetrics = TaskMetrics.empty)
+ val context = new TaskContext(0, 0, 0)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -86,8 +83,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}

whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false,
-   taskMetrics = TaskMetrics.empty)
+ val context = new TaskContext(0, 0, 0, runningLocally = true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
4 changes: 1 addition & 3 deletions core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
@@ -178,14 +178,12 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
-   taskMetrics = TaskMetrics.empty)
+ val tContext = new TaskContext(0, 0, 0)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
} else {
// printenv isn't available so just pass the test
- assert(true)
}
}

