diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 339266a5d48b2..a50600f1488c9 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaFutureAction import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.JobWaiter +import org.apache.spark.util.ThreadUtils /** @@ -45,6 +46,7 @@ trait FutureAction[T] extends Future[T] { /** * Blocks until this action completes. + * * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf * for unbounded waiting, or a finite positive duration * @return this FutureAction @@ -53,6 +55,7 @@ trait FutureAction[T] extends Future[T] { /** * Awaits and returns the result (of type T) of this action. + * * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf * for unbounded waiting, or a finite positive duration * @throws Exception exception during action execution @@ -89,8 +92,8 @@ trait FutureAction[T] extends Future[T] { /** * Blocks and returns the result of this job. */ - @throws(classOf[Exception]) - def get(): T = Await.result(this, Duration.Inf) + @throws(classOf[SparkException]) + def get(): T = ThreadUtils.awaitResult(this, Duration.Inf) /** * Returns the job IDs run by the underlying async operation. diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index abb98f95a1ee8..79f4d06c8460e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import java.util.concurrent.TimeoutException import scala.collection.mutable.ListBuffer -import scala.concurrent.{Await, Future, Promise} +import scala.concurrent.{Future, Promise} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.language.postfixOps @@ -35,7 +35,7 @@ import org.json4s.jackson.JsonMethods import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.deploy.master.RecoveryState import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master. @@ -265,7 +265,7 @@ private object FaultToleranceTest extends App with Logging { } // Avoid waiting indefinitely (e.g., we could register but get no executors). - assertTrue(Await.result(f, 120 seconds)) + assertTrue(ThreadUtils.awaitResult(f, 120 seconds)) } /** @@ -318,7 +318,7 @@ private object FaultToleranceTest extends App with Logging { } try { - assertTrue(Await.result(f, 120 seconds)) + assertTrue(ThreadUtils.awaitResult(f, 120 seconds)) } catch { case e: TimeoutException => logError("Master states: " + masters.map(_.state)) @@ -422,7 +422,7 @@ private object SparkDocker { } dockerCmd.run(ProcessLogger(findIpAndLog _)) - val ip = Await.result(ipPromise.future, 30 seconds) + val ip = ThreadUtils.awaitResult(ipPromise.future, 30 seconds) val dockerId = Docker.getLastProcessId (ip, dockerId, outFile) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index b443e8f0519f4..edc9be2a8a8cb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -24,7 +24,7 @@ import java.util.Date import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.language.postfixOps import scala.util.Random @@ -959,7 +959,7 @@ private[deploy] class Master( */ private[master] def rebuildSparkUI(app: ApplicationInfo): Option[SparkUI] = { val futureUI = asyncRebuildSparkUI(app) - Await.result(futureUI, Duration.Inf) + ThreadUtils.awaitResult(futureUI, Duration.Inf) } /** Rebuild a new SparkUI asynchronously to not block RPC event loop */ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index c5a5876a896cc..21cb94142b15b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -27,10 +27,11 @@ import scala.collection.mutable import scala.concurrent.{Await, Future} import scala.concurrent.duration._ import scala.io.Source +import scala.util.control.NonFatal import com.fasterxml.jackson.core.JsonProcessingException -import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -258,13 +259,17 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { } } + // scalastyle:off awaitresult try { Await.result(responseFuture, 10.seconds) } catch { + // scalastyle:on awaitresult case unreachable @ (_: FileNotFoundException | _: SocketException) => throw new SubmitRestConnectionException("Unable to connect to server", unreachable) case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) => throw new SubmitRestProtocolException("Malformed response received from server", malformed) case timeout: TimeoutException => throw new SubmitRestConnectionException("No response from server", timeout) + case NonFatal(t) => + throw new SparkException("Exception while waiting for response", t) } } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 09ce012e4e692..cb9d389dd7ea6 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -20,7 +20,7 @@ package org.apache.spark.network import java.io.Closeable import java.nio.ByteBuffer -import scala.concurrent.{Await, Future, Promise} +import scala.concurrent.{Future, Promise} import scala.concurrent.duration.Duration import scala.reflect.ClassTag @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.util.ThreadUtils private[spark] abstract class BlockTransferService extends ShuffleClient with Closeable with Logging { @@ -100,8 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo result.success(new NioManagedBuffer(ret)) } }) - - Await.result(result.future, Duration.Inf) + ThreadUtils.awaitResult(result.future, Duration.Inf) } /** @@ -119,6 +119,6 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo level: StorageLevel, classTag: ClassTag[_]): Unit = { val future = uploadBlock(hostname, port, execId, blockId, blockData, level, classTag) - Await.result(future, Duration.Inf) + ThreadUtils.awaitResult(future, Duration.Inf) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 2950df62bf285..2761d39e37029 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -19,10 +19,11 @@ package org.apache.spark.rpc import java.util.concurrent.TimeoutException -import scala.concurrent.{Await, Awaitable} +import scala.concurrent.{Await, Future} import scala.concurrent.duration._ +import scala.util.control.NonFatal -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.util.Utils /** @@ -65,14 +66,21 @@ private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: S /** * Wait for the completed result and return it. If the result is not available within this * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. - * @param awaitable the `Awaitable` to be awaited - * @throws RpcTimeoutException if after waiting for the specified time `awaitable` + * + * @param future the `Future` to be awaited + * @throws RpcTimeoutException if after waiting for the specified time `future` * is still not ready */ - def awaitResult[T](awaitable: Awaitable[T]): T = { + def awaitResult[T](future: Future[T]): T = { + val wrapAndRethrow: PartialFunction[Throwable, T] = { + case NonFatal(t) => + throw new SparkException("Exception thrown in awaitResult", t) + } try { - Await.result(awaitable, duration) - } catch addMessageIfTimeout + // scalastyle:off awaitresult + Await.result(future, duration) + // scalastyle:on awaitresult + } catch addMessageIfTimeout.orElse(wrapAndRethrow) } } @@ -82,6 +90,7 @@ private[spark] object RpcTimeout { /** * Lookup the timeout property in the configuration and create * a RpcTimeout with the property key in the description. + * * @param conf configuration properties containing the timeout * @param timeoutProp property key for the timeout in seconds * @throws NoSuchElementException if property is not set @@ -95,6 +104,7 @@ private[spark] object RpcTimeout { * Lookup the timeout property in the configuration and create * a RpcTimeout with the property key in the description. * Uses the given default value if property is not set + * * @param conf configuration properties containing the timeout * @param timeoutProp property key for the timeout in seconds * @param defaultValue default timeout value in seconds if property not found @@ -109,6 +119,7 @@ private[spark] object RpcTimeout { * and create a RpcTimeout with the first set property key in the * description. * Uses the given default value if property is not set + * * @param conf configuration properties containing the timeout * @param timeoutPropList prioritized list of property keys for the timeout in seconds * @param defaultValue default timeout value in seconds if no properties found diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 35a6c63ad193e..22bc76b143516 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -260,7 +260,12 @@ private[spark] class BlockManager( def waitForAsyncReregister(): Unit = { val task = asyncReregisterTask if (task != null) { - Await.ready(task, Duration.Inf) + try { + Await.ready(task, Duration.Inf) + } catch { + case NonFatal(t) => + throw new Exception("Error occurred while waiting for async. reregistration", t) + } } } @@ -802,7 +807,12 @@ private[spark] class BlockManager( logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) if (level.replication > 1) { // Wait for asynchronous replication to finish - Await.ready(replicationFuture, Duration.Inf) + try { + Await.ready(replicationFuture, Duration.Inf) + } catch { + case NonFatal(t) => + throw new Exception("Error occurred while waiting for replication to finish", t) + } } if (blockWasSuccessfullyStored) { None diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 9abbf4a7a3971..5a6dbc830448a 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,12 +19,15 @@ package org.apache.spark.util import java.util.concurrent._ -import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} +import scala.concurrent.{Await, Awaitable, ExecutionContext, ExecutionContextExecutor} +import scala.concurrent.duration.Duration import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} +import org.apache.spark.SparkException + private[spark] object ThreadUtils { private val sameThreadExecutionContext = @@ -174,4 +177,21 @@ private[spark] object ThreadUtils { false // asyncMode ) } + + // scalastyle:off awaitresult + /** + * Preferred alternative to [[Await.result()]]. This method wraps and re-throws any exceptions + * thrown by the underlying [[Await]] call, ensuring that this thread's stack trace appears in + * logs. + */ + @throws(classOf[SparkException]) + def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = { + try { + Await.result(awaitable, atMost) + // scalastyle:on awaitresult + } catch { + case NonFatal(t) => + throw new SparkException("Exception thrown in awaitResult: ", t) + } + } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 78e164cff7738..848f7d7adbc7e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1598,6 +1598,7 @@ private[spark] object Utils extends Logging { /** * Timing method based on iterations that permit JVM JIT optimization. + * * @param numIters number of iterations * @param f function to be executed. If prepare is not None, the running time of each call to f * must be an order of magnitude longer than one millisecond for accurate timing. @@ -1639,6 +1640,7 @@ private[spark] object Utils extends Logging { /** * Creates a symlink. + * * @param src absolute path to the source * @param dst relative path for the destination */ diff --git a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala index 1102aea96b548..70b6309be7d53 100644 --- a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala +++ b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark -import scala.concurrent.Await import scala.concurrent.duration.Duration import org.scalatest.{BeforeAndAfter, Matchers} +import org.apache.spark.util.ThreadUtils + class FutureActionSuite extends SparkFunSuite @@ -36,7 +37,7 @@ class FutureActionSuite test("simple async action") { val rdd = sc.parallelize(1 to 10, 2) val job = rdd.countAsync() - val res = Await.result(job, Duration.Inf) + val res = ThreadUtils.awaitResult(job, Duration.Inf) res should be (10) job.jobIds.size should be (1) } @@ -44,7 +45,7 @@ class FutureActionSuite test("complex async action") { val rdd = sc.parallelize(1 to 15, 3) val job = rdd.takeAsync(10) - val res = Await.result(job, Duration.Inf) + val res = ThreadUtils.awaitResult(job, Duration.Inf) res should be (1 to 10) job.jobIds.size should be (2) } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 713d5e58b4ffc..4d2b3e7f3b14b 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -21,7 +21,6 @@ import java.util.concurrent.{ExecutorService, TimeUnit} import scala.collection.Map import scala.collection.mutable -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -36,7 +35,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.ManualClock +import org.apache.spark.util.{ManualClock, ThreadUtils} /** * A test suite for the heartbeating behavior between the driver and the executors. @@ -231,14 +230,14 @@ class HeartbeatReceiverSuite private def addExecutorAndVerify(executorId: String): Unit = { assert( heartbeatReceiver.addExecutor(executorId).map { f => - Await.result(f, 10.seconds) + ThreadUtils.awaitResult(f, 10.seconds) } === Some(true)) } private def removeExecutorAndVerify(executorId: String): Unit = { assert( heartbeatReceiver.removeExecutor(executorId).map { f => - Await.result(f, 10.seconds) + ThreadUtils.awaitResult(f, 10.seconds) } === Some(true)) } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index c347ab8dc8020..a3490fc79e458 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark import java.util.concurrent.Semaphore -import scala.concurrent.Await import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.Future @@ -28,6 +27,7 @@ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import org.apache.spark.util.ThreadUtils /** * Test suite for cancelling running jobs. We run the cancellation tasks for single job action @@ -137,7 +137,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sc.clearJobGroup() val jobB = sc.parallelize(1 to 100, 2).countAsync() sc.cancelJobGroup("jobA") - val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) } + val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, Duration.Inf) }.getCause assert(e.getMessage contains "cancel") // Once A is cancelled, job B should finish fairly quickly. @@ -202,7 +202,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sc.clearJobGroup() val jobB = sc.parallelize(1 to 100, 2).countAsync() sc.cancelJobGroup("jobA") - val e = intercept[SparkException] { Await.result(jobA, 5.seconds) } + val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 5.seconds) }.getCause assert(e.getMessage contains "cancel") // Once A is cancelled, job B should finish fairly quickly. @@ -248,7 +248,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft { val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() Future { f.cancel() } - val e = intercept[SparkException] { f.get() } + val e = intercept[SparkException] { f.get() }.getCause assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } @@ -268,7 +268,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sem.acquire() f.cancel() } - val e = intercept[SparkException] { f.get() } + val e = intercept[SparkException] { f.get() }.getCause assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } } @@ -278,7 +278,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft { val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) Future { f.cancel() } - val e = intercept[SparkException] { f.get() } + val e = intercept[SparkException] { f.get() }.getCause assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } @@ -296,7 +296,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sem.acquire() f.cancel() } - val e = intercept[SparkException] { f.get() } + val e = intercept[SparkException] { f.get() }.getCause assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 99d5b496bcd2e..a1286523a235d 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.memory import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import org.mockito.Matchers.{any, anyLong} @@ -33,6 +33,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel} import org.apache.spark.storage.memory.MemoryStore +import org.apache.spark.util.ThreadUtils /** @@ -172,15 +173,15 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft // Have both tasks request 500 bytes, then wait until both requests have been granted: val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result1, futureTimeout) === 500L) - assert(Await.result(t2Result1, futureTimeout) === 500L) + assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 500L) + assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 500L) // Have both tasks each request 500 bytes more; both should immediately return 0 as they are // both now at 1 / N val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result2, 200.millis) === 0L) - assert(Await.result(t2Result2, 200.millis) === 0L) + assert(ThreadUtils.awaitResult(t1Result2, 200.millis) === 0L) + assert(ThreadUtils.awaitResult(t2Result2, 200.millis) === 0L) } test("two tasks cannot grow past 1 / N of on-heap execution memory") { @@ -192,15 +193,15 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft // Have both tasks request 250 bytes, then wait until both requests have been granted: val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result1, futureTimeout) === 250L) - assert(Await.result(t2Result1, futureTimeout) === 250L) + assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 250L) + assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 250L) // Have both tasks each request 500 bytes more. // We should only grant 250 bytes to each of them on this second request val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result2, futureTimeout) === 250L) - assert(Await.result(t2Result2, futureTimeout) === 250L) + assert(ThreadUtils.awaitResult(t1Result2, futureTimeout) === 250L) + assert(ThreadUtils.awaitResult(t2Result2, futureTimeout) === 250L) } test("tasks can block to get at least 1 / 2N of on-heap execution memory") { @@ -211,17 +212,17 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result1, futureTimeout) === 1000L) + assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 1000L) val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult // to make sure the other thread blocks for some time otherwise. Thread.sleep(300) t1MemManager.releaseExecutionMemory(250L, MemoryMode.ON_HEAP, null) // The memory freed from t1 should now be granted to t2. - assert(Await.result(t2Result1, futureTimeout) === 250L) + assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 250L) // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory. val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t2Result2, 200.millis) === 0L) + assert(ThreadUtils.awaitResult(t2Result2, 200.millis) === 0L) } test("TaskMemoryManager.cleanUpAllAllocatedMemory") { @@ -232,18 +233,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result1, futureTimeout) === 1000L) + assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 1000L) val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult // to make sure the other thread blocks for some time otherwise. Thread.sleep(300) // t1 releases all of its memory, so t2 should be able to grab all of the memory t1MemManager.cleanUpAllAllocatedMemory() - assert(Await.result(t2Result1, futureTimeout) === 500L) + assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 500L) val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t2Result2, futureTimeout) === 500L) + assert(ThreadUtils.awaitResult(t2Result2, futureTimeout) === 500L) val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t2Result3, 200.millis) === 0L) + assert(ThreadUtils.awaitResult(t2Result3, 200.millis) === 0L) } test("tasks should not be granted a negative amount of execution memory") { @@ -254,13 +255,13 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft val futureTimeout: Duration = 20.seconds val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result1, futureTimeout) === 700L) + assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 700L) val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t2Result1, futureTimeout) === 300L) + assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 300L) val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result2, 200.millis) === 0L) + assert(ThreadUtils.awaitResult(t1Result2, 200.millis) === 0L) } test("off-heap execution allocations cannot exceed limit") { @@ -270,11 +271,11 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft val tMemManager = new TaskMemoryManager(memoryManager, 1) val result1 = Future { tMemManager.acquireExecutionMemory(1000L, MemoryMode.OFF_HEAP, null) } - assert(Await.result(result1, 200.millis) === 1000L) + assert(ThreadUtils.awaitResult(result1, 200.millis) === 1000L) assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) val result2 = Future { tMemManager.acquireExecutionMemory(300L, MemoryMode.OFF_HEAP, null) } - assert(Await.result(result2, 200.millis) === 0L) + assert(ThreadUtils.awaitResult(result2, 200.millis) === 0L) assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null) diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index d18bde790b40a..8cb0a295b0773 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.util.ThreadUtils class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts { @@ -185,22 +186,23 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim test("FutureAction result, infinite wait") { val f = sc.parallelize(1 to 100, 4) .countAsync() - assert(Await.result(f, Duration.Inf) === 100) + assert(ThreadUtils.awaitResult(f, Duration.Inf) === 100) } test("FutureAction result, finite wait") { val f = sc.parallelize(1 to 100, 4) .countAsync() - assert(Await.result(f, Duration(30, "seconds")) === 100) + assert(ThreadUtils.awaitResult(f, Duration(30, "seconds")) === 100) } test("FutureAction result, timeout") { val f = sc.parallelize(1 to 100, 4) .mapPartitions(itr => { Thread.sleep(20); itr }) .countAsync() - intercept[TimeoutException] { - Await.result(f, Duration(20, "milliseconds")) + val e = intercept[SparkException] { + ThreadUtils.awaitResult(f, Duration(20, "milliseconds")) } + assert(e.getCause.isInstanceOf[TimeoutException]) } private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = { @@ -221,7 +223,7 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim // Now allow the executors to proceed with task processing. starter.release(rdd.partitions.length) // Waiting for the result verifies that the tasks were successfully processed. - Await.result(executionContextInvoked.future, atMost = 15.seconds) + ThreadUtils.awaitResult(executionContextInvoked.future, atMost = 15.seconds) } test("SimpleFutureAction callback must not consume a thread while waiting") { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index cebac2097f380..73803ec21a567 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -35,7 +35,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Common tests for an RpcEnv implementation. @@ -415,7 +415,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { }) val f = endpointRef.ask[String]("Hi") - val ack = Await.result(f, 5 seconds) + val ack = ThreadUtils.awaitResult(f, 5 seconds) assert("ack" === ack) env.stop(endpointRef) @@ -435,7 +435,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely") try { val f = rpcEndpointRef.ask[String]("hello") - val ack = Await.result(f, 5 seconds) + val ack = ThreadUtils.awaitResult(f, 5 seconds) assert("ack" === ack) } finally { anotherEnv.shutdown() @@ -454,9 +454,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val f = endpointRef.ask[String]("Hi") val e = intercept[SparkException] { - Await.result(f, 5 seconds) + ThreadUtils.awaitResult(f, 5 seconds) } - assert("Oops" === e.getMessage) + assert("Oops" === e.getCause.getMessage) env.stop(endpointRef) } @@ -476,9 +476,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { try { val f = rpcEndpointRef.ask[String]("hello") val e = intercept[SparkException] { - Await.result(f, 5 seconds) + ThreadUtils.awaitResult(f, 5 seconds) } - assert("Oops" === e.getMessage) + assert("Oops" === e.getCause.getMessage) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() @@ -487,6 +487,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { /** * Setup an [[RpcEndpoint]] to collect all network events. + * * @return the [[RpcEndpointRef]] and an `ConcurrentLinkedQueue` that contains network events. */ private def setupNetworkEndpoint( @@ -620,10 +621,10 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { anotherEnv.setupEndpointRef(env.address, "sendWithReply-unserializable-error") try { val f = rpcEndpointRef.ask[String]("hello") - val e = intercept[Exception] { - Await.result(f, 1 seconds) + val e = intercept[SparkException] { + ThreadUtils.awaitResult(f, 1 seconds) } - assert(e.isInstanceOf[NotSerializableException]) + assert(e.getCause.isInstanceOf[NotSerializableException]) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() @@ -754,15 +755,17 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { // RpcTimeout.awaitResult should have added the property to the TimeoutException message assert(reply2.contains(shortTimeout.timeoutProp)) - // Ask with delayed response and allow the Future to timeout before Await.result + // Ask with delayed response and allow the Future to timeout before ThreadUtils.awaitResult val fut3 = rpcEndpointRef.ask[String](NeverReply("goodbye"), shortTimeout) + // scalastyle:off awaitresult // Allow future to complete with failure using plain Await.result, this will return // once the future is complete to verify addMessageIfTimeout was invoked val reply3 = intercept[RpcTimeoutException] { Await.result(fut3, 2000 millis) }.getMessage + // scalastyle:on awaitresult // When the future timed out, the recover callback should have used // RpcTimeout.addMessageIfTimeout to add the property to the TimeoutException message diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index 994a58836bd0d..2d6543d328618 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.rpc.netty -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark._ import org.apache.spark.rpc._ class NettyRpcEnvSuite extends RpcEnvSuite { @@ -34,10 +34,11 @@ class NettyRpcEnvSuite extends RpcEnvSuite { test("non-existent endpoint") { val uri = RpcEndpointAddress(env.address, "nonexist-endpoint").toString - val e = intercept[RpcEndpointNotFoundException] { + val e = intercept[SparkException] { env.setupEndpointRef(env.address, "nonexist-endpoint") } - assert(e.getMessage.contains(uri)) + assert(e.getCause.isInstanceOf[RpcEndpointNotFoundException]) + assert(e.getCause.getMessage.contains(uri)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 8e509de7677c3..83288db92bb43 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.scheduler import java.io.File import java.util.concurrent.TimeoutException -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -33,7 +32,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.rdd.{FakeOutputCommitter, RDD} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Unit tests for the output commit coordination functionality. @@ -159,9 +158,10 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { 0 until rdd.partitions.size, resultHandler, () => Unit) // It's an error if the job completes successfully even though no committer was authorized, // so throw an exception if the job was allowed to complete. - intercept[TimeoutException] { - Await.result(futureAction, 5 seconds) + val e = intercept[SparkException] { + ThreadUtils.awaitResult(futureAction, 5 seconds) } + assert(e.getCause.isInstanceOf[TimeoutException]) assert(tempDir.list().size === 0) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 9d1bd7ec89bc7..9ee83b76e71dc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.BeforeAndAfterEach import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkException, SparkFunSuite, TaskContext, TaskContextImpl} +import org.apache.spark.util.ThreadUtils class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -124,8 +125,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { } // After downgrading to a read lock, both threads should wake up and acquire the shared // read lock. - assert(!Await.result(lock1Future, 1.seconds)) - assert(!Await.result(lock2Future, 1.seconds)) + assert(!ThreadUtils.awaitResult(lock1Future, 1.seconds)) + assert(!ThreadUtils.awaitResult(lock2Future, 1.seconds)) assert(blockInfoManager.get("block").get.readerCount === 3) } @@ -161,7 +162,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { withTaskId(winningTID) { blockInfoManager.unlock("block") } - assert(!Await.result(losingFuture, 1.seconds)) + assert(!ThreadUtils.awaitResult(losingFuture, 1.seconds)) assert(blockInfoManager.get("block").get.readerCount === 1) } @@ -262,8 +263,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { withTaskId(0) { blockInfoManager.unlock("block") } - assert(Await.result(get1Future, 1.seconds).isDefined) - assert(Await.result(get2Future, 1.seconds).isDefined) + assert(ThreadUtils.awaitResult(get1Future, 1.seconds).isDefined) + assert(ThreadUtils.awaitResult(get2Future, 1.seconds).isDefined) assert(blockInfoManager.get("block").get.readerCount === 2) } @@ -288,13 +289,14 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { blockInfoManager.unlock("block") } assert( - Await.result(Future.firstCompletedOf(Seq(write1Future, write2Future)), 1.seconds).isDefined) + ThreadUtils.awaitResult( + Future.firstCompletedOf(Seq(write1Future, write2Future)), 1.seconds).isDefined) val firstWriteWinner = if (write1Future.isCompleted) 1 else 2 withTaskId(firstWriteWinner) { blockInfoManager.unlock("block") } - assert(Await.result(write1Future, 1.seconds).isDefined) - assert(Await.result(write2Future, 1.seconds).isDefined) + assert(ThreadUtils.awaitResult(write1Future, 1.seconds).isDefined) + assert(ThreadUtils.awaitResult(write2Future, 1.seconds).isDefined) } test("removing a non-existent block throws IllegalArgumentException") { @@ -344,8 +346,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { withTaskId(0) { blockInfoManager.removeBlock("block") } - assert(Await.result(getFuture, 1.seconds).isEmpty) - assert(Await.result(writeFuture, 1.seconds).isEmpty) + assert(ThreadUtils.awaitResult(getFuture, 1.seconds).isEmpty) + assert(ThreadUtils.awaitResult(writeFuture, 1.seconds).isEmpty) } test("releaseAllLocksForTask releases write locks") { diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 6652a41b6990b..ae3b3d829f1bb 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.util import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.concurrent.{Await, Future} +import scala.concurrent.Future import scala.concurrent.duration._ import scala.util.Random @@ -109,7 +109,7 @@ class ThreadUtilsSuite extends SparkFunSuite { val f = Future { Thread.currentThread().getName() }(ThreadUtils.sameThread) - val futureThreadName = Await.result(f, 10.seconds) + val futureThreadName = ThreadUtils.awaitResult(f, 10.seconds) assert(futureThreadName === callerThreadName) } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index a14e3e583f870..e39400e2d1840 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -192,6 +192,17 @@ This file is divided into 3 sections: ]]> + + Await\.result + + + JavaConversions diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 260dfb3f42244..94e676ded601b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ThreadUtils /** * Additional tests for code generation. @@ -43,7 +44,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } } - futures.foreach(Await.result(_, 10.seconds)) + futures.foreach(ThreadUtils.awaitResult(_, 10.seconds)) } test("SPARK-8443: split wide projections into blocks due to JVM code size limit") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 4091f65aecb50..415cd4d84a23f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ +import scala.util.control.NonFatal import org.apache.spark.{broadcast, SparkEnv} import org.apache.spark.internal.Logging @@ -167,7 +168,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def waitForSubqueries(): Unit = { // fill in the result of subqueries subqueryResults.foreach { case (e, futureResult) => - val rows = Await.result(futureResult, Duration.Inf) + val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf) if (rows.length > 1) { sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala index 102a9356df311..a4f42133425ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.exchange -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import org.apache.spark.broadcast @@ -81,8 +81,7 @@ case class BroadcastExchange( } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - val result = Await.result(relationFuture, timeout) - result.asInstanceOf[broadcast.Broadcast[T]] + ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index eb49eabcb1ba9..0d0f556d9eae3 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -23,7 +23,7 @@ import java.sql.Timestamp import java.util.Date import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{Await, Promise} +import scala.concurrent.Promise import scala.concurrent.duration._ import org.apache.hadoop.hive.conf.HiveConf.ConfVars @@ -32,7 +32,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary @@ -132,7 +132,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { - Await.result(foundAllExpectedAnswers.future, timeout) + ThreadUtils.awaitResult(foundAllExpectedAnswers.future, timeout) } catch { case cause: Throwable => val message = s""" diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index a1268b8e94f56..ee14b6dc8d01d 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -24,7 +24,7 @@ import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{Await, ExecutionContext, Future, Promise} +import scala.concurrent.{ExecutionContext, Future, Promise} import scala.concurrent.duration._ import scala.io.Source import scala.util.{Random, Try} @@ -40,7 +40,7 @@ import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer @@ -373,9 +373,10 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { // slightly more conservatively than may be strictly necessary. Thread.sleep(1000) statement.cancel() - val e = intercept[SQLException] { - Await.result(f, 3.minute) - } + val e = intercept[SparkException] { + ThreadUtils.awaitResult(f, 3.minute) + }.getCause + assert(e.isInstanceOf[SQLException]) assert(e.getMessage.contains("cancelled")) // Cancellation is a no-op if spark.sql.hive.thriftServer.async=false @@ -391,7 +392,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { // might race and complete before we issue the cancel. Thread.sleep(1000) statement.cancel() - val rs1 = Await.result(sf, 3.minute) + val rs1 = ThreadUtils.awaitResult(sf, 3.minute) rs1.next() assert(rs1.getInt(1) === math.pow(5, 5)) rs1.close() @@ -814,7 +815,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl process } - Await.result(serverStarted.future, SERVER_STARTUP_TIMEOUT) + ThreadUtils.awaitResult(serverStarted.future, SERVER_STARTUP_TIMEOUT) } private def stopThriftServer(): Unit = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index f381fa4094e74..a7d870500fd02 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.receiver -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.language.{existentials, postfixOps} import scala.util.control.NonFatal @@ -213,7 +213,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Combine the futures, wait for both to complete, and return the write ahead log record handle val combinedFuture = storeInBlockManagerFuture.zip(storeInWriteAheadLogFuture).map(_._2) - val walRecordHandle = Await.result(combinedFuture, blockStoreTimeout) + val walRecordHandle = ThreadUtils.awaitResult(combinedFuture, blockStoreTimeout) WriteAheadLogBasedStoreResult(blockId, numRecords, walRecordHandle) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index 165e81ea41a98..71f3304f1ba73 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -23,14 +23,14 @@ import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{Await, Promise} +import scala.concurrent.Promise import scala.concurrent.duration._ import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * A wrapper for a WriteAheadLog that batches records before writing data. Handles aggregation @@ -80,7 +80,8 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp } } if (putSuccessfully) { - Await.result(promise.future, WriteAheadLogUtils.getBatchingTimeout(conf).milliseconds) + ThreadUtils.awaitResult( + promise.future, WriteAheadLogUtils.getBatchingTimeout(conf).milliseconds) } else { throw new IllegalStateException("close() was called on BatchedWriteAheadLog before " + s"write request with time $time could be fulfilled.") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 8c980dee2cc06..24cb5afee33c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -38,7 +38,7 @@ import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ import org.scalatest.mock.MockitoSugar -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{CompletionIterator, ManualClock, ThreadUtils, Utils} @@ -471,10 +471,11 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( // the BatchedWriteAheadLog should bubble up any exceptions that may have happened during writes val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) - intercept[RuntimeException] { + val e = intercept[SparkException] { val buffer = mock[ByteBuffer] batchedWal.write(buffer, 2L) } + assert(e.getCause.getMessage === "Hello!") } // we make the write requests in separate threads so that we don't block the test thread