Skip to content

Commit

Permalink
fixed SPARK-4514 by calling setLocalProperties before submitting jobs…
Browse files Browse the repository at this point in the history
… within takeAsync
  • Loading branch information
reggert committed Nov 22, 2015
1 parent 5816489 commit 8fe8000
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
*/
def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope {
val callSite = self.context.getCallSite
val localProperties = self.context.getLocalProperties
// Cached thread pool to handle aggregation of subtasks.
implicit val executionContext = AsyncRDDActions.futureExecutionContext
val results = new ArrayBuffer[T](num)
Expand Down Expand Up @@ -102,6 +103,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi

val buf = new Array[Array[T]](p.size)
self.context.setCallSite(callSite)
self.context.setLocalProperties(localProperties)
val job = jobSubmitter.submitJob(self,
(it: Iterator[T]) => it.take(left).toArray,
p,
Expand Down
15 changes: 14 additions & 1 deletion core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont
test("getJobIdsForGroup() with takeAsync()") {
sc = new SparkContext("local", "test", new SparkConf(false))
sc.setJobGroup("my-job-group2", "description")
sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq.empty)
sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty
val firstJobFuture = sc.parallelize(1 to 1000, 1).takeAsync(1)
val firstJobId = eventually(timeout(10 seconds)) {
firstJobFuture.jobIds.head
Expand All @@ -99,4 +99,17 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont
sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq(firstJobId))
}
}

test("getJobIdsForGroup() with takeAsync() across multiple partitions") {
sc = new SparkContext("local", "test", new SparkConf(false))
sc.setJobGroup("my-job-group2", "description")
sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty
val firstJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999)
val firstJobId = eventually(timeout(10 seconds)) {
firstJobFuture.jobIds.head
}
eventually(timeout(10 seconds)) {
sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2
}
}
}

0 comments on commit 8fe8000

Please sign in to comment.