From 8fe8000ee5ada0f82ab171a27cab75e8142debcb Mon Sep 17 00:00:00 2001
From: "Richard W. Eggert II"
Date: Sun, 22 Nov 2015 16:17:00 -0500
Subject: [PATCH] fixed SPARK-4514 by calling setLocalProperties before
 submitting jobs within takeAsync

---
 .../org/apache/spark/rdd/AsyncRDDActions.scala  |  2 ++
 .../org/apache/spark/StatusTrackerSuite.scala   | 15 ++++++++++++++-
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 303a152565348..14f541f937b4c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -65,6 +65,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
    */
   def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope {
     val callSite = self.context.getCallSite
+    val localProperties = self.context.getLocalProperties
     // Cached thread pool to handle aggregation of subtasks.
     implicit val executionContext = AsyncRDDActions.futureExecutionContext
     val results = new ArrayBuffer[T](num)
@@ -102,6 +103,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
 
         val buf = new Array[Array[T]](p.size)
         self.context.setCallSite(callSite)
+        self.context.setLocalProperties(localProperties)
         val job = jobSubmitter.submitJob(self,
           (it: Iterator[T]) => it.take(left).toArray,
           p,
diff --git a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
index 61e030fc26a82..5483f2b8434aa 100644
--- a/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala
@@ -90,7 +90,7 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont
   test("getJobIdsForGroup() with takeAsync()") {
     sc = new SparkContext("local", "test", new SparkConf(false))
     sc.setJobGroup("my-job-group2", "description")
-    sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq.empty)
+    sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty
     val firstJobFuture = sc.parallelize(1 to 1000, 1).takeAsync(1)
     val firstJobId = eventually(timeout(10 seconds)) {
       firstJobFuture.jobIds.head
@@ -99,4 +99,17 @@ class StatusTrackerSuite extends SparkFunSuite with Matchers with LocalSparkCont
       sc.statusTracker.getJobIdsForGroup("my-job-group2") should be (Seq(firstJobId))
     }
   }
+
+  test("getJobIdsForGroup() with takeAsync() across multiple partitions") {
+    sc = new SparkContext("local", "test", new SparkConf(false))
+    sc.setJobGroup("my-job-group2", "description")
+    sc.statusTracker.getJobIdsForGroup("my-job-group2") shouldBe empty
+    val firstJobFuture = sc.parallelize(1 to 1000, 2).takeAsync(999)
+    val firstJobId = eventually(timeout(10 seconds)) {
+      firstJobFuture.jobIds.head
+    }
+    eventually(timeout(10 seconds)) {
+      sc.statusTracker.getJobIdsForGroup("my-job-group2") should have size 2
+    }
+  }
 }
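
For context, below is a minimal sketch (not part of the patch) of how the fixed
behavior looks from user code. setJobGroup stores the group id in the calling
thread's local properties, and takeAsync submits its jobs from a pooled thread,
so without this patch those jobs would not appear under the group. The sketch
assumes a local-mode SparkContext; the object name TakeAsyncJobGroupDemo and
the group id "demo-group" are made up for illustration.

    import scala.concurrent.Await
    import scala.concurrent.duration._

    import org.apache.spark.{SparkConf, SparkContext}

    // Hypothetical demo; names here are illustrative only.
    object TakeAsyncJobGroupDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local").setAppName("demo"))

        // Stores "demo-group" in this thread's local properties.
        sc.setJobGroup("demo-group", "jobs submitted via takeAsync")

        // takeAsync may submit more than one job (it scans further partitions
        // when the first pass returns too few elements). With the patch, each
        // of those jobs carries the captured local properties, so all of them
        // show up under "demo-group" and respond to cancelJobGroup.
        val future = sc.parallelize(1 to 1000, 2).takeAsync(999)
        val taken = Await.result(future, 30.seconds)

        println(s"took ${taken.size} elements")
        println(s"jobs in group: " +
          sc.statusTracker.getJobIdsForGroup("demo-group").mkString(", "))

        sc.stop()
      }
    }

With two partitions and takeAsync(999), the group should report two job ids,
which is exactly what the new multi-partition test above asserts.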