@@ -52,8 +52,9 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](

private var interrupted = false

// Time at which this sender should finish if the response stream is not finished by then.
private var deadlineTimeMillis = Long.MaxValue
// Time at which this sender should finish if the response stream is not finished by then. The
// value is updated on each execute call.
private var deadlineTimeNs = 0L

// Signal to wake up when grpcCallObserver.isReady()
private val grpcCallObserverReadySignal = new Object
@@ -74,8 +75,8 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
}

// For testing
private[connect] def setDeadline(deadlineMs: Long) = {
deadlineTimeMillis = deadlineMs
private[connect] def setDeadline(deadlineNs: Long) = {
deadlineTimeNs = deadlineNs
wakeUp()
}

@@ -171,6 +172,24 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
}
}

/**
* Update the deadline for this sender: the System.nanoTime value at which this sender should
* finish even if the response stream is not finished, computed relative to `startTime`.
*/
private def updateDeadlineTimeNs(startTime: Long): Unit = {
if (executeHolder.reattachable) {
val confSize =
SparkEnv.get.conf.get(CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION)
if (confSize > 0) {
deadlineTimeNs = startTime + confSize * NANOS_PER_MILLIS
return
}
}

// We cannot use Long.MaxValue as the default duration: System.nanoTime() has an arbitrary
// origin (and may be negative), so adding a huge duration to it can overflow. We use 180 days
// as an effectively unbounded maximum duration.
deadlineTimeNs = startTime + (1000L * 60L * 60L * 24L * 180L * NANOS_PER_MILLIS)
}
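The 180-day cap replaces the previous Long.MaxValue sentinel. A minimal, hypothetical standalone sketch (not part of this change) of why Long.MaxValue cannot serve as a nanoTime-based deadline:

object NanoDeadlineSketch {
  private val NANOS_PER_MILLIS = 1000000L

  def main(args: Array[String]): Unit = {
    val start = System.nanoTime() // arbitrary origin; any Long value is possible
    // For any positive reading, adding Long.MaxValue wraps around, yielding a
    // deadline that is already "in the past":
    val overflowed = start + Long.MaxValue
    println(s"overflowed deadline already passed: ${overflowed < System.nanoTime()}")
    // A large but bounded duration (180 days) stays comfortably representable:
    val deadline = start + 1000L * 60L * 60L * 24L * 180L * NANOS_PER_MILLIS
    println(s"bounded deadline already passed: ${deadline < System.nanoTime()}")
  }
}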

/**
* Attach to the executionObserver, consume responses from it, and send them to grpcObserver.
*
@@ -188,20 +207,13 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
log"Starting for opId=${MDC(OP_ID, executeHolder.operationId)}, " +
log"reattachable=${MDC(REATTACHABLE, executeHolder.reattachable)}, " +
log"lastConsumedStreamIndex=${MDC(STREAM_ID, lastConsumedStreamIndex)}")

val startTime = System.nanoTime()
updateDeadlineTimeNs(startTime)

var nextIndex = lastConsumedStreamIndex + 1
var finished = false

// Time at which this sender should finish if the response stream is not finished by then.
deadlineTimeMillis = if (!executeHolder.reattachable) {
Long.MaxValue
} else {
val confSize =
SparkEnv.get.conf.get(CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION)
if (confSize > 0) System.currentTimeMillis() + confSize else Long.MaxValue
}

// Maximum total size of responses. The response which tips over this threshold will be sent.
val maximumResponseSize: Long = if (!executeHolder.reattachable) {
Long.MaxValue
@@ -223,7 +235,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
def streamFinished = executionObserver.getLastResponseIndex().exists(nextIndex > _)
// 4. time deadline or size limit reached
def deadlineLimitReached =
sentResponsesSize > maximumResponseSize || deadlineTimeMillis < System.currentTimeMillis()
sentResponsesSize > maximumResponseSize || deadlineTimeNs < System.nanoTime()

logTrace(s"Trying to get next response with index=$nextIndex.")
executionObserver.responseLock.synchronized {
@@ -241,16 +253,16 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
// The state of interrupted, response and lastIndex are changed under executionObserver
// monitor, and will notify upon state change.
if (response.isEmpty) {
var timeout = Math.max(1, deadlineTimeMillis - System.currentTimeMillis())
var timeoutNs = Math.max(NANOS_PER_MILLIS, deadlineTimeNs - System.nanoTime())
// Wake up more frequently to send the progress updates.
val progressTimeout = executeHolder.sessionHolder.session.sessionState.conf
.getConf(CONNECT_PROGRESS_REPORT_INTERVAL)
// If the progress feature is disabled, wait for the deadline.
if (progressTimeout > 0L) {
timeout = Math.min(progressTimeout, timeout)
timeoutNs = Math.min(progressTimeout * NANOS_PER_MILLIS, timeoutNs)
}
logTrace(s"Wait for response to become available with timeout=$timeout ms.")
executionObserver.responseLock.wait(timeout)
logTrace(s"Wait for response to become available with timeout=$timeoutNs ns.")
executionObserver.responseLock.wait(timeoutNs / NANOS_PER_MILLIS)
enqueueProgressMessage(force = true)
logTrace(s"Reacquired executionObserver lock after waiting.")
sleepEnd = System.nanoTime()
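One subtlety in the wait above: Object.wait takes milliseconds, and wait(0) means "wait indefinitely", so a sub-millisecond nanosecond budget must be rounded up before the division (hence the Math.max(NANOS_PER_MILLIS, ...) guard). A hypothetical helper distilling that conversion:

object WaitBudgetSketch {
  private val NANOS_PER_MILLIS = 1000000L

  /** Turn a remaining nanosecond budget into a safe Object.wait argument. */
  def waitMillis(remainingNs: Long): Long =
    // Clamp to at least one millisecond: a result of 0 would make
    // Object.wait block forever instead of honoring the deadline.
    Math.max(NANOS_PER_MILLIS, remainingNs) / NANOS_PER_MILLIS
}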
@@ -283,7 +295,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
} else if (gotResponse) {
enqueueProgressMessage()
// There is a response available to be sent.
val sent = sendResponse(response.get, deadlineTimeMillis)
val sent = sendResponse(response.get, deadlineTimeNs)
if (sent) {
sentResponsesSize += response.get.serializedByteSize
nextIndex += 1
@@ -330,14 +342,12 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
* In reattachable execution, we control the backpressure and only send when the
* grpcCallObserver is in fact ready to send.
*
* @param deadlineTimeMillis
* @param deadlineTimeNs
* when reattachable, wait for ready stream until this deadline.
* @return
* true if the response was sent, false otherwise (meaning deadline passed)
*/
private def sendResponse(
response: CachedStreamResponse[T],
deadlineTimeMillis: Long): Boolean = {
private def sendResponse(response: CachedStreamResponse[T], deadlineTimeNs: Long): Boolean = {
if (!executeHolder.reattachable) {
// no flow control in non-reattachable execute
logDebug(
Expand Down Expand Up @@ -369,11 +379,11 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
// 3. time deadline is reached
while (!interrupted &&
!grpcCallObserver.isReady() &&
deadlineTimeMillis >= System.currentTimeMillis()) {
val timeout = Math.max(1, deadlineTimeMillis - System.currentTimeMillis())
deadlineTimeNs >= System.nanoTime()) {
val timeoutNs = Math.max(NANOS_PER_MILLIS, deadlineTimeNs - System.nanoTime())
var sleepStart = System.nanoTime()
logTrace(s"Wait for grpcCallObserver to become ready with timeout=$timeout ms.")
grpcCallObserverReadySignal.wait(timeout)
logTrace(s"Wait for grpcCallObserver to become ready with timeout=$timeoutNs ns.")
grpcCallObserverReadySignal.wait(timeoutNs / NANOS_PER_MILLIS)
logTrace(s"Reacquired grpcCallObserverReadySignal lock after waiting.")
sleepEnd = System.nanoTime()
}
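The same wait-in-a-loop shape appears here as in the response wait above: re-check the predicate after every wakeup (monitors can wake spuriously) and bound each wait by the remaining deadline. A self-contained sketch of the pattern, with hypothetical names:

object DeadlineWaitSketch {
  private val NANOS_PER_MILLIS = 1000000L

  /**
   * Wait on `lock` until `condition` holds or the System.nanoTime deadline
   * passes. Must be called while holding the monitor (inside lock.synchronized).
   * Returns true if the condition was observed before the deadline.
   */
  def awaitUntil(lock: Object, deadlineNs: Long)(condition: => Boolean): Boolean = {
    while (!condition && System.nanoTime() < deadlineNs) {
      val remainingMs = Math.max(1L, (deadlineNs - System.nanoTime()) / NANOS_PER_MILLIS)
      lock.wait(remainingMs) // loop re-checks the condition against spurious wakeups
    }
    condition
  }
}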
@@ -89,18 +89,18 @@ private[connect] class ExecuteHolder(

private val runner: ExecuteThreadRunner = new ExecuteThreadRunner(this)

/** System.currentTimeMillis when this ExecuteHolder was created. */
val creationTimeMs = System.currentTimeMillis()
/** System.nanoTime when this ExecuteHolder was created. */
val creationTimeNs = System.nanoTime()

/**
* None if there is currently an attached RPC (grpcResponseSenders not empty or during initial
* ExecutePlan handler). Otherwise, the System.currentTimeMillis when the last RPC detached
* ExecutePlan handler). Otherwise, the System.nanoTime when the last RPC detached
* (grpcResponseSenders became empty).
*/
@volatile var lastAttachedRpcTimeMs: Option[Long] = None
@volatile var lastAttachedRpcTimeNs: Option[Long] = None

/** System.currentTimeMillis when this ExecuteHolder was closed. */
private var closedTimeMs: Option[Long] = None
/** System.nanoTime when this ExecuteHolder was closed. */
private var closedTimeNs: Option[Long] = None
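Because these fields now hold raw System.nanoTime readings, they carry no wall-clock meaning; per the System.nanoTime contract they should only be compared to other readings, ideally by subtraction. A brief hypothetical illustration:

object NanoTimestampSketch {
  def main(args: Array[String]): Unit = {
    val createdNs = System.nanoTime()
    Thread.sleep(5)
    val closedNs = System.nanoTime()
    // Durations via subtraction are always meaningful, even across wraparound:
    println(s"lifetime: ${(closedNs - createdNs) / 1000000L} ms")
    // The absolute values are not: they share no origin with wall-clock time.
    println(s"createdNs=$createdNs is unrelated to System.currentTimeMillis()")
  }
}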

/**
* Attached ExecuteGrpcResponseSenders that send the GRPC responses.
@@ -161,13 +161,13 @@ private[connect] class ExecuteHolder(

private def addGrpcResponseSender(
sender: ExecuteGrpcResponseSender[proto.ExecutePlanResponse]) = synchronized {
if (closedTimeMs.isEmpty) {
if (closedTimeNs.isEmpty) {
// Interrupt all other senders - there can be only one active sender.
// Interrupted senders will remove themselves with removeGrpcResponseSender when they exit.
grpcResponseSenders.foreach(_.interrupt())
// And add this one.
grpcResponseSenders += sender
lastAttachedRpcTimeMs = None
lastAttachedRpcTimeNs = None
} else {
// execution is closing... interrupt it already.
sender.interrupt()
@@ -176,18 +176,18 @@ private[connect] class ExecuteHolder(

def removeGrpcResponseSender(sender: ExecuteGrpcResponseSender[_]): Unit = synchronized {
// if closed, we are shutting down and interrupting all senders already
if (closedTimeMs.isEmpty) {
if (closedTimeNs.isEmpty) {
grpcResponseSenders -=
sender.asInstanceOf[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]]
if (grpcResponseSenders.isEmpty) {
lastAttachedRpcTimeMs = Some(System.currentTimeMillis())
lastAttachedRpcTimeNs = Some(System.nanoTime())
}
}
}

// For testing.
private[connect] def setGrpcResponseSendersDeadline(deadlineMs: Long) = synchronized {
grpcResponseSenders.foreach(_.setDeadline(deadlineMs))
private[connect] def setGrpcResponseSendersDeadline(deadlineNs: Long) = synchronized {
grpcResponseSenders.foreach(_.setDeadline(deadlineNs))
}

// For testing
@@ -201,9 +201,9 @@ private[connect] class ExecuteHolder(
* don't get garbage collected. End this grace period when the initial ExecutePlan ends.
*/
def afterInitialRPC(): Unit = synchronized {
if (closedTimeMs.isEmpty) {
if (closedTimeNs.isEmpty) {
if (grpcResponseSenders.isEmpty) {
lastAttachedRpcTimeMs = Some(System.currentTimeMillis())
lastAttachedRpcTimeNs = Some(System.nanoTime())
}
}
}
@@ -233,20 +233,20 @@ private[connect] class ExecuteHolder(
* execution from global tracking and from its session.
*/
def close(): Unit = synchronized {
if (closedTimeMs.isEmpty) {
if (closedTimeNs.isEmpty) {
// interrupt execution, if still running.
val interrupted = runner.interrupt()
// interrupt any attached grpcResponseSenders
grpcResponseSenders.foreach(_.interrupt())
// if there were still any grpcResponseSenders, register detach time
if (grpcResponseSenders.nonEmpty) {
lastAttachedRpcTimeMs = Some(System.currentTimeMillis())
lastAttachedRpcTimeNs = Some(System.nanoTime())
grpcResponseSenders.clear()
}
if (!interrupted) {
cleanup()
}
closedTimeMs = Some(System.currentTimeMillis())
closedTimeNs = Some(System.nanoTime())
}
}

@@ -283,9 +295,9 @@ private[connect] class ExecuteHolder(
sparkSessionTags = sparkSessionTags,
reattachable = reattachable,
status = eventsManager.status,
creationTimeMs = creationTimeMs,
lastAttachedRpcTimeMs = lastAttachedRpcTimeMs,
closedTimeMs = closedTimeMs)
creationTimeNs = creationTimeNs,
lastAttachedRpcTimeNs = lastAttachedRpcTimeNs,
closedTimeNs = closedTimeNs)
}

/** Get key used by SparkConnectExecutionManager global tracker. */
@@ -358,9 +358,9 @@ case class ExecuteInfo(
sparkSessionTags: Set[String],
reattachable: Boolean,
status: ExecuteStatus,
creationTimeMs: Long,
lastAttachedRpcTimeMs: Option[Long],
closedTimeMs: Option[Long]) {
creationTimeNs: Long,
lastAttachedRpcTimeNs: Option[Long],
closedTimeNs: Option[Long]) {

def key: ExecuteKey = ExecuteKey(userId, sessionId, operationId)
}
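A usage sketch (hypothetical helper, not part of this change) showing how ExecuteInfo's nanoTime fields compose into durations:

object ExecuteInfoAges {
  /** Nanoseconds since the execution was created, as of `nowNs`. */
  def ageNs(info: ExecuteInfo, nowNs: Long = System.nanoTime()): Long =
    nowNs - info.creationTimeNs

  /** Nanoseconds the execution has been detached, if it currently is. */
  def detachedForNs(info: ExecuteInfo, nowNs: Long = System.nanoTime()): Option[Long] =
    info.lastAttachedRpcTimeNs.map(nowNs - _)
}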
@@ -30,6 +30,7 @@ import com.google.common.cache.CacheBuilder
import org.apache.spark.{SparkEnv, SparkSQLException}
import org.apache.spark.connect.proto
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
import org.apache.spark.sql.connect.config.Connect.{CONNECT_EXECUTE_MANAGER_ABANDONED_TOMBSTONES_SIZE, CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT, CONNECT_EXECUTE_MANAGER_MAINTENANCE_INTERVAL}
import org.apache.spark.util.ThreadUtils

@@ -73,7 +74,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
.build[ExecuteKey, ExecuteInfo]()

/** The time when the last execution was removed. */
private var lastExecutionTimeMs: AtomicLong = new AtomicLong(System.currentTimeMillis())
private var lastExecutionTimeNs: AtomicLong = new AtomicLong(System.nanoTime())

/** Executor for the periodic maintenance */
private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
@@ -175,12 +176,12 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
}

/**
* If there are no executions, return Left with System.currentTimeMillis of last active
* execution. Otherwise return Right with list of ExecuteInfo of all executions.
* If there are no executions, return Left with the System.nanoTime at which the last execution
* was removed. Otherwise return Right with the list of ExecuteInfo of all executions.
*/
def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = {
if (executions.isEmpty) {
Left(lastExecutionTimeMs.getAcquire())
Left(lastExecutionTimeNs.getAcquire())
} else {
Right(executions.values().asScala.map(_.getExecuteInfo).toBuffer.toSeq)
}
@@ -211,7 +212,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
* Updates the last execution time after the last execution has been removed.
*/
private def updateLastExecutionTime(): Unit = {
lastExecutionTimeMs.getAndUpdate(prev => prev.max(System.currentTimeMillis()))
lastExecutionTimeNs.getAndUpdate(prev => prev.max(System.nanoTime()))
}
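The getAndUpdate-with-max idiom keeps the stored reading monotonically non-decreasing under concurrent updates, without locking. The same idiom in isolation (hypothetical names):

import java.util.concurrent.atomic.AtomicLong

object MonotonicTimestampSketch {
  private val lastSeenNs = new AtomicLong(System.nanoTime())

  /** Record a reading, never moving the stored value backwards. */
  def record(nowNs: Long = System.nanoTime()): Unit =
    lastSeenNs.getAndUpdate(prev => prev.max(nowNs))

  def lastNs: Long = lastSeenNs.getAcquire()
}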

/**
@@ -231,8 +232,9 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
executor.scheduleAtFixedRate(
() => {
try {
val timeout = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT)
periodicMaintenance(timeout)
val timeoutNs =
SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT) * NANOS_PER_MILLIS
periodicMaintenance(timeoutNs)
} catch {
case NonFatal(ex) => logWarning("Unexpected exception in periodic task", ex)
}
@@ -245,15 +247,15 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
}

// Visible for testing.
private[connect] def periodicMaintenance(timeout: Long): Unit = {
private[connect] def periodicMaintenance(timeoutNs: Long): Unit = {
// Find any detached executions that expired and should be removed.
logInfo("Started periodic run of SparkConnectExecutionManager maintenance.")

val nowMs = System.currentTimeMillis()
val nowNs = System.nanoTime()
executions.forEach((_, executeHolder) => {
executeHolder.lastAttachedRpcTimeMs match {
case Some(detached) =>
if (detached + timeout <= nowMs) {
executeHolder.lastAttachedRpcTimeNs match {
case Some(detachedNs) =>
if (detachedNs + timeoutNs <= nowNs) {
val info = executeHolder.getExecuteInfo
logInfo(
log"Found execution ${MDC(LogKeys.EXECUTE_INFO, info)} that was abandoned " +
@@ -268,8 +270,8 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
}

// For testing.
private[connect] def setAllRPCsDeadline(deadlineMs: Long) = {
executions.values().asScala.foreach(_.setGrpcResponseSendersDeadline(deadlineMs))
private[connect] def setAllRPCsDeadline(deadlineNs: Long) = {
executions.values().asScala.foreach(_.setGrpcResponseSendersDeadline(deadlineNs))
}

// For testing.
@@ -167,7 +167,7 @@ trait SparkConnectServerTest extends SharedSparkSession {
case Right(executions) =>
// all rpc detached.
assert(
executions.forall(_.lastAttachedRpcTimeMs.isDefined),
executions.forall(_.lastAttachedRpcTimeNs.isDefined),
s"Expected no RPCs, but got $executions")
}
}
@@ -43,7 +43,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {

iter.next() // open iterator, guarantees that the RPC reached the server
// expire all RPCs on server
SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1)
SparkConnectService.executionManager.setAllRPCsDeadline(System.nanoTime() - 1)
assertEventuallyNoActiveRpcs()
// iterator should reattach
// (but not necessarily at first next, as there might have been messages buffered client side)
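The expiry trick used here generalizes: any nanoTime-based deadline can be forced into the past by handing the component System.nanoTime() - 1. A hypothetical condensed skeleton of the pattern (withClient and buildPlan are assumed helpers from this suite):

test("deadline can be forced into the past") {
  withClient { client =>
    val iter = client.execute(buildPlan("select * from range(1000000)"))
    iter.next() // guarantees the RPC reached the server
    // Any deadline computed from System.nanoTime() is now already expired:
    SparkConnectService.executionManager.setAllRPCsDeadline(System.nanoTime() - 1)
    assertEventuallyNoActiveRpcs()
  }
}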
@@ -155,7 +155,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
// open the iterator, guarantees that the RPC reached the server
iter.next()
// disconnect and remove on server
SparkConnectService.executionManager.setAllRPCsDeadline(System.currentTimeMillis() - 1)
SparkConnectService.executionManager.setAllRPCsDeadline(System.nanoTime() - 1)
assertEventuallyNoActiveRpcs()
SparkConnectService.executionManager.periodicMaintenance(0)
assertNoActiveExecutions()