diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala index 21b5e057fb77e..ff59789980ce7 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala @@ -52,8 +52,9 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( private var interrupted = false - // Time at which this sender should finish if the response stream is not finished by then. - private var deadlineTimeMillis = Long.MaxValue + // Time at which this sender should finish if the response stream is not finished by then. The + // value is updated on each execute call. + private var deadlineTimeNs = 0L // Signal to wake up when grpcCallObserver.isReady() private val grpcCallObserverReadySignal = new Object @@ -74,8 +75,8 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( } // For testing - private[connect] def setDeadline(deadlineMs: Long) = { - deadlineTimeMillis = deadlineMs + private[connect] def setDeadline(deadlineNs: Long) = { + deadlineTimeNs = deadlineNs wakeUp() } @@ -171,6 +172,24 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( } } + /** + * Update the deadline for this sender. The deadline is the time when this sender should finish. + */ + private def updateDeadlineTimeNs(startTime: Long): Unit = { + if (executeHolder.reattachable) { + val confSize = + SparkEnv.get.conf.get(CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION) + if (confSize > 0) { + deadlineTimeNs = startTime + confSize * NANOS_PER_MILLIS + return + } + } + + // We cannot use Long.MaxValue as the default timeout duration, because System.nanoTime() may + // return a negative value. We use 180 days as the maximum duration. + deadlineTimeNs = startTime + (1000L * 60L * 60L * 24L * 180L * NANOS_PER_MILLIS) + } + /** * Attach to the executionObserver, consume responses from it, and send them to grpcObserver. * @@ -188,20 +207,13 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( log"Starting for opId=${MDC(OP_ID, executeHolder.operationId)}, " + log"reattachable=${MDC(REATTACHABLE, executeHolder.reattachable)}, " + log"lastConsumedStreamIndex=${MDC(STREAM_ID, lastConsumedStreamIndex)}") + val startTime = System.nanoTime() + updateDeadlineTimeNs(startTime) var nextIndex = lastConsumedStreamIndex + 1 var finished = false - // Time at which this sender should finish if the response stream is not finished by then. - deadlineTimeMillis = if (!executeHolder.reattachable) { - Long.MaxValue - } else { - val confSize = - SparkEnv.get.conf.get(CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION) - if (confSize > 0) System.currentTimeMillis() + confSize else Long.MaxValue - } - // Maximum total size of responses. The response which tips over this threshold will be sent. val maximumResponseSize: Long = if (!executeHolder.reattachable) { Long.MaxValue @@ -223,7 +235,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( def streamFinished = executionObserver.getLastResponseIndex().exists(nextIndex > _) // 4. 
time deadline or size limit reached def deadlineLimitReached = - sentResponsesSize > maximumResponseSize || deadlineTimeMillis < System.currentTimeMillis() + sentResponsesSize > maximumResponseSize || deadlineTimeNs < System.nanoTime() logTrace(s"Trying to get next response with index=$nextIndex.") executionObserver.responseLock.synchronized { @@ -241,16 +253,16 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( // The state of interrupted, response and lastIndex are changed under executionObserver // monitor, and will notify upon state change. if (response.isEmpty) { - var timeout = Math.max(1, deadlineTimeMillis - System.currentTimeMillis()) + var timeoutNs = Math.max(1, deadlineTimeNs - System.nanoTime()) // Wake up more frequently to send the progress updates. val progressTimeout = executeHolder.sessionHolder.session.sessionState.conf .getConf(CONNECT_PROGRESS_REPORT_INTERVAL) // If the progress feature is disabled, wait for the deadline. if (progressTimeout > 0L) { - timeout = Math.min(progressTimeout, timeout) + timeoutNs = Math.min(progressTimeout * NANOS_PER_MILLIS, timeoutNs) } - logTrace(s"Wait for response to become available with timeout=$timeout ms.") - executionObserver.responseLock.wait(timeout) + logTrace(s"Wait for response to become available with timeout=$timeoutNs ns.") + executionObserver.responseLock.wait(timeoutNs / NANOS_PER_MILLIS) enqueueProgressMessage(force = true) logTrace(s"Reacquired executionObserver lock after waiting.") sleepEnd = System.nanoTime() @@ -283,7 +295,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( } else if (gotResponse) { enqueueProgressMessage() // There is a response available to be sent. - val sent = sendResponse(response.get, deadlineTimeMillis) + val sent = sendResponse(response.get, deadlineTimeNs) if (sent) { sentResponsesSize += response.get.serializedByteSize nextIndex += 1 @@ -330,14 +342,12 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( * In reattachable execution, we control the backpressure and only send when the * grpcCallObserver is in fact ready to send. * - * @param deadlineTimeMillis + * @param deadlineTimeNs * when reattachable, wait for ready stream until this deadline. * @return * true if the response was sent, false otherwise (meaning deadline passed) */ - private def sendResponse( - response: CachedStreamResponse[T], - deadlineTimeMillis: Long): Boolean = { + private def sendResponse(response: CachedStreamResponse[T], deadlineTimeNs: Long): Boolean = { if (!executeHolder.reattachable) { // no flow control in non-reattachable execute logDebug( @@ -369,11 +379,11 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message]( // 3. 
time deadline is reached while (!interrupted && !grpcCallObserver.isReady() && - deadlineTimeMillis >= System.currentTimeMillis()) { - val timeout = Math.max(1, deadlineTimeMillis - System.currentTimeMillis()) + deadlineTimeNs >= System.nanoTime()) { + val timeoutNs = Math.max(1, deadlineTimeNs - System.nanoTime()) var sleepStart = System.nanoTime() - logTrace(s"Wait for grpcCallObserver to become ready with timeout=$timeout ms.") - grpcCallObserverReadySignal.wait(timeout) + logTrace(s"Wait for grpcCallObserver to become ready with timeout=$timeoutNs ns.") + grpcCallObserverReadySignal.wait(timeoutNs / NANOS_PER_MILLIS) logTrace(s"Reacquired grpcCallObserverReadySignal lock after waiting.") sleepEnd = System.nanoTime() } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index 94638151f7f18..1b8fa1b08473f 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -89,18 +89,18 @@ private[connect] class ExecuteHolder( private val runner: ExecuteThreadRunner = new ExecuteThreadRunner(this) - /** System.currentTimeMillis when this ExecuteHolder was created. */ - val creationTimeMs = System.currentTimeMillis() + /** System.nanoTime when this ExecuteHolder was created. */ + val creationTimeNs = System.nanoTime() /** * None if there is currently an attached RPC (grpcResponseSenders not empty or during initial - * ExecutePlan handler). Otherwise, the System.currentTimeMillis when the last RPC detached + * ExecutePlan handler). Otherwise, the System.nanoTime when the last RPC detached * (grpcResponseSenders became empty). */ - @volatile var lastAttachedRpcTimeMs: Option[Long] = None + @volatile var lastAttachedRpcTimeNs: Option[Long] = None - /** System.currentTimeMillis when this ExecuteHolder was closed. */ - private var closedTimeMs: Option[Long] = None + /** System.nanoTime when this ExecuteHolder was closed. */ + private var closedTimeNs: Option[Long] = None /** * Attached ExecuteGrpcResponseSenders that send the GRPC responses. @@ -161,13 +161,13 @@ private[connect] class ExecuteHolder( private def addGrpcResponseSender( sender: ExecuteGrpcResponseSender[proto.ExecutePlanResponse]) = synchronized { - if (closedTimeMs.isEmpty) { + if (closedTimeNs.isEmpty) { // Interrupt all other senders - there can be only one active sender. // Interrupted senders will remove themselves with removeGrpcResponseSender when they exit. grpcResponseSenders.foreach(_.interrupt()) // And add this one. grpcResponseSenders += sender - lastAttachedRpcTimeMs = None + lastAttachedRpcTimeNs = None } else { // execution is closing... interrupt it already. sender.interrupt() @@ -176,18 +176,18 @@ private[connect] class ExecuteHolder( def removeGrpcResponseSender(sender: ExecuteGrpcResponseSender[_]): Unit = synchronized { // if closed, we are shutting down and interrupting all senders already - if (closedTimeMs.isEmpty) { + if (closedTimeNs.isEmpty) { grpcResponseSenders -= sender.asInstanceOf[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]] if (grpcResponseSenders.isEmpty) { - lastAttachedRpcTimeMs = Some(System.currentTimeMillis()) + lastAttachedRpcTimeNs = Some(System.nanoTime()) } } } // For testing. 
- private[connect] def setGrpcResponseSendersDeadline(deadlineMs: Long) = synchronized { - grpcResponseSenders.foreach(_.setDeadline(deadlineMs)) + private[connect] def setGrpcResponseSendersDeadline(deadlineNs: Long) = synchronized { + grpcResponseSenders.foreach(_.setDeadline(deadlineNs)) } // For testing @@ -201,9 +201,9 @@ private[connect] class ExecuteHolder( * don't get garbage collected. End this grace period when the initial ExecutePlan ends. */ def afterInitialRPC(): Unit = synchronized { - if (closedTimeMs.isEmpty) { + if (closedTimeNs.isEmpty) { if (grpcResponseSenders.isEmpty) { - lastAttachedRpcTimeMs = Some(System.currentTimeMillis()) + lastAttachedRpcTimeNs = Some(System.nanoTime()) } } } @@ -233,20 +233,20 @@ private[connect] class ExecuteHolder( * execution from global tracking and from its session. */ def close(): Unit = synchronized { - if (closedTimeMs.isEmpty) { + if (closedTimeNs.isEmpty) { // interrupt execution, if still running. val interrupted = runner.interrupt() // interrupt any attached grpcResponseSenders grpcResponseSenders.foreach(_.interrupt()) // if there were still any grpcResponseSenders, register detach time if (grpcResponseSenders.nonEmpty) { - lastAttachedRpcTimeMs = Some(System.currentTimeMillis()) + lastAttachedRpcTimeNs = Some(System.nanoTime()) grpcResponseSenders.clear() } if (!interrupted) { cleanup() } - closedTimeMs = Some(System.currentTimeMillis()) + closedTimeNs = Some(System.nanoTime()) } } @@ -283,9 +283,9 @@ private[connect] class ExecuteHolder( sparkSessionTags = sparkSessionTags, reattachable = reattachable, status = eventsManager.status, - creationTimeMs = creationTimeMs, - lastAttachedRpcTimeMs = lastAttachedRpcTimeMs, - closedTimeMs = closedTimeMs) + creationTimeNs = creationTimeNs, + lastAttachedRpcTimeNs = lastAttachedRpcTimeNs, + closedTimeNs = closedTimeNs) } /** Get key used by SparkConnectExecutionManager global tracker. */ @@ -358,9 +358,9 @@ case class ExecuteInfo( sparkSessionTags: Set[String], reattachable: Boolean, status: ExecuteStatus, - creationTimeMs: Long, - lastAttachedRpcTimeMs: Option[Long], - closedTimeMs: Option[Long]) { + creationTimeNs: Long, + lastAttachedRpcTimeNs: Option[Long], + closedTimeNs: Option[Long]) { def key: ExecuteKey = ExecuteKey(userId, sessionId, operationId) } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index f750ca6db67a8..a156be189c650 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -30,6 +30,7 @@ import com.google.common.cache.CacheBuilder import org.apache.spark.{SparkEnv, SparkSQLException} import org.apache.spark.connect.proto import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.connect.config.Connect.{CONNECT_EXECUTE_MANAGER_ABANDONED_TOMBSTONES_SIZE, CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT, CONNECT_EXECUTE_MANAGER_MAINTENANCE_INTERVAL} import org.apache.spark.util.ThreadUtils @@ -73,7 +74,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { .build[ExecuteKey, ExecuteInfo]() /** The time when the last execution was removed. 
*/ - private var lastExecutionTimeMs: AtomicLong = new AtomicLong(System.currentTimeMillis()) + private var lastExecutionTimeNs: AtomicLong = new AtomicLong(System.nanoTime()) /** Executor for the periodic maintenance */ private var scheduledExecutor: AtomicReference[ScheduledExecutorService] = @@ -175,12 +176,12 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } /** - * If there are no executions, return Left with System.currentTimeMillis of last active - * execution. Otherwise return Right with list of ExecuteInfo of all executions. + * If there are no executions, return Left with System.nanoTime of last active execution. + * Otherwise return Right with list of ExecuteInfo of all executions. */ def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = { if (executions.isEmpty) { - Left(lastExecutionTimeMs.getAcquire()) + Left(lastExecutionTimeNs.getAcquire()) } else { Right(executions.values().asScala.map(_.getExecuteInfo).toBuffer.toSeq) } @@ -211,7 +212,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { * Updates the last execution time after the last execution has been removed. */ private def updateLastExecutionTime(): Unit = { - lastExecutionTimeMs.getAndUpdate(prev => prev.max(System.currentTimeMillis())) + lastExecutionTimeNs.getAndUpdate(prev => prev.max(System.nanoTime())) } /** @@ -231,8 +232,9 @@ private[connect] class SparkConnectExecutionManager() extends Logging { executor.scheduleAtFixedRate( () => { try { - val timeout = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT) - periodicMaintenance(timeout) + val timeoutNs = + SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT) * NANOS_PER_MILLIS + periodicMaintenance(timeoutNs) } catch { case NonFatal(ex) => logWarning("Unexpected exception in periodic task", ex) } @@ -245,15 +247,15 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } // Visible for testing. - private[connect] def periodicMaintenance(timeout: Long): Unit = { + private[connect] def periodicMaintenance(timeoutNs: Long): Unit = { // Find any detached executions that expired and should be removed. logInfo("Started periodic run of SparkConnectExecutionManager maintenance.") - val nowMs = System.currentTimeMillis() + val nowNs = System.nanoTime() executions.forEach((_, executeHolder) => { - executeHolder.lastAttachedRpcTimeMs match { - case Some(detached) => - if (detached + timeout <= nowMs) { + executeHolder.lastAttachedRpcTimeNs match { + case Some(detachedNs) => + if (detachedNs + timeoutNs <= nowNs) { val info = executeHolder.getExecuteInfo logInfo( log"Found execution ${MDC(LogKeys.EXECUTE_INFO, info)} that was abandoned " + @@ -268,8 +270,8 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } // For testing. - private[connect] def setAllRPCsDeadline(deadlineMs: Long) = { - executions.values().asScala.foreach(_.setGrpcResponseSendersDeadline(deadlineMs)) + private[connect] def setAllRPCsDeadline(deadlineNs: Long) = { + executions.values().asScala.foreach(_.setGrpcResponseSendersDeadline(deadlineNs)) } // For testing. 
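Aside (not part of the patch): the maintenance change above converts the configured detach timeout from milliseconds to nanoseconds and then checks `detachedNs + timeoutNs <= nowNs`, i.e. it reasons purely in elapsed time on the monotonic clock. Below is a minimal, self-contained sketch of that expiry pattern; `DetachTracker` is a hypothetical helper, not part of Spark Connect, and it uses the overflow-safe `now - start >= timeout` form of the same comparison.

```scala
import java.util.concurrent.TimeUnit

// Sketch only: a hypothetical helper mirroring the "last detach time + timeout" check in
// periodicMaintenance; DetachTracker is illustrative and not part of Spark Connect.
class DetachTracker(timeoutMs: Long) {
  private val timeoutNs: Long = TimeUnit.MILLISECONDS.toNanos(timeoutMs)

  // None while an RPC is attached; Some(System.nanoTime()) once the last RPC detached.
  @volatile private var lastDetachedNs: Option[Long] = None

  def markAttached(): Unit = { lastDetachedNs = None }
  def markDetached(): Unit = { lastDetachedNs = Some(System.nanoTime()) }

  // System.nanoTime() values only make sense as differences, so the check compares the
  // elapsed time (now - detached) against the timeout rather than absolute timestamps.
  def isExpired(nowNs: Long = System.nanoTime()): Boolean =
    lastDetachedNs.exists(detachedNs => nowNs - detachedNs >= timeoutNs)
}

object DetachTrackerExample {
  def main(args: Array[String]): Unit = {
    val tracker = new DetachTracker(timeoutMs = 5L * 60L * 1000L) // e.g. a 5-minute timeout
    tracker.markDetached()
    println(s"abandoned yet? ${tracker.isExpired()}") // false right after detaching
  }
}
```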
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala index 3c857554dc756..db430549818de 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala @@ -167,7 +167,7 @@ trait SparkConnectServerTest extends SharedSparkSession { case Right(executions) => // all rpc detached. assert( - executions.forall(_.lastAttachedRpcTimeMs.isDefined), + executions.forall(_.lastAttachedRpcTimeNs.isDefined), s"Expected no RPCs, but got $executions") } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala index 2606284c25bd5..7cab12871300a 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala @@ -43,7 +43,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest { iter.next() // open iterator, guarantees that the RPC reached the server // expire all RPCs on server - SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1) + SparkConnectService.executionManager.setAllRPCsDeadline(System.nanoTime() - 1) assertEventuallyNoActiveRpcs() // iterator should reattach // (but not necessarily at first next, as there might have been messages buffered client side) @@ -155,7 +155,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest { // open the iterator, guarantees that the RPC reached the server iter.next() // disconnect and remove on server - SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1) + SparkConnectService.executionManager.setAllRPCsDeadline(System.nanoTime() - 1) assertEventuallyNoActiveRpcs() SparkConnectService.executionManager.periodicMaintenance(0) assertNoActiveExecutions()
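Aside (not part of the patch): the core of `updateDeadlineTimeNs` is that the "no deadline" case cannot be `Long.MaxValue`, because the deadline is formed as `startTime + duration` and adding `Long.MaxValue` to `System.nanoTime()` (which has an arbitrary, possibly negative, origin) overflows; the diff caps the duration at 180 days instead. The sketch below, assuming an illustrative `DeadlineSketch` object that is not Spark's API, shows that arithmetic plus the conversion of the remaining nanoseconds back to milliseconds for `Object.wait`; clamping to a minimum of 1 ms after the division is my own guard against `wait(0)` meaning "wait until notified", not something the diff does.

```scala
import java.util.concurrent.TimeUnit

// Sketch only: illustrates the deadline arithmetic used by the patch; not Spark's API.
object DeadlineSketch {
  private val NanosPerMilli: Long = TimeUnit.MILLISECONDS.toNanos(1)

  // Roughly 180 days, used in place of Long.MaxValue so that startNs + duration cannot
  // overflow on the System.nanoTime() timeline.
  private val EffectivelyUnlimitedNs: Long = TimeUnit.DAYS.toNanos(180)

  // Deadline on the nanoTime timeline: startNs plus the configured maximum stream
  // duration, or the 180-day cap when the duration is unset (<= 0).
  def deadlineNs(startNs: Long, maxStreamDurationMs: Long): Long =
    if (maxStreamDurationMs > 0) startNs + maxStreamDurationMs * NanosPerMilli
    else startNs + EffectivelyUnlimitedNs

  // Object.wait takes milliseconds; clamping to at least 1 ms avoids wait(0),
  // which would block until notified rather than time out.
  def remainingWaitMs(deadlineNs: Long): Long =
    math.max(1L, (deadlineNs - System.nanoTime()) / NanosPerMilli)

  def main(args: Array[String]): Unit = {
    val start = System.nanoTime()
    val deadline = deadlineNs(start, maxStreamDurationMs = 2L * 60L * 1000L) // 2-minute cap
    println(s"deadline expired: ${deadline < System.nanoTime()}")
    println(s"next wait would be ${remainingWaitMs(deadline)} ms")
  }
}
```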