@@ -18,6 +18,7 @@
package org.apache.spark.sql.connect.service

import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
+import java.util.concurrent.atomic.AtomicReference

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
@@ -41,7 +42,8 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
  // The server side listener that is responsible to stream streaming query events back to client.
  // There is only one listener per sessionHolder, but each listener is responsible for all events
  // of all streaming queries in the SparkSession.
-  var streamingQueryServerSideListener: Option[SparkConnectListenerBusListener] = None
+  var streamingQueryServerSideListener: AtomicReference[SparkConnectListenerBusListener] =
+    new AtomicReference()
  // The cache for QueryStartedEvent, key is query runId and value is the actual QueryStartedEvent.
  // Events for corresponding query will be sent back to client with
  // the WriteStreamOperationStart response, so that the client can handle the event before
@@ -50,10 +52,8 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
  val streamingQueryStartedEventCache
    : ConcurrentMap[String, StreamingQueryListener.QueryStartedEvent] = new ConcurrentHashMap()

-  val lock = new Object()
-
-  def isServerSideListenerRegistered: Boolean = lock.synchronized {
-    streamingQueryServerSideListener.isDefined
+  def isServerSideListenerRegistered: Boolean = {
+    streamingQueryServerSideListener.getAcquire() != null
  }

  /**
@@ -65,10 +65,10 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
   * @param responseObserver
   *   the responseObserver created from the first long running executeThread.
   */
-  def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = lock.synchronized {
+  def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
    val serverListener = new SparkConnectListenerBusListener(this, responseObserver)
    sessionHolder.session.streams.addListener(serverListener)
-    streamingQueryServerSideListener = Some(serverListener)
+    streamingQueryServerSideListener.setRelease(serverListener)
  }

  /**
@@ -77,13 +77,13 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
   * exception. It removes the listener from the session, clears the cache. Also it sends back the
   * final ResultComplete response.
   */
-  def cleanUp(): Unit = lock.synchronized {
-    streamingQueryServerSideListener.foreach { listener =>
+  def cleanUp(): Unit = {
+    var listener = streamingQueryServerSideListener.getAndSet(null)
+    if (listener != null) {
      sessionHolder.session.streams.removeListener(listener)
      listener.sendResultComplete()
+      streamingQueryStartedEventCache.clear()
    }
-    streamingQueryStartedEventCache.clear()
-    streamingQueryServerSideListener = None
  }
}
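
Note on the pattern above: the holder drops its `lock` object in favor of an AtomicReference. setRelease publishes the listener, getAcquire reads it with matching ordering, and getAndSet(null) lets exactly one caller win the teardown race, making cleanUp() safe to call more than once. A minimal sketch of the same idiom; the ResourceHolder name and onClose hook are illustrative, not part of this PR:

import java.util.concurrent.atomic.AtomicReference

// Illustrative stand-in for the listener holder's new lock-free scheme.
final class ResourceHolder[T <: AnyRef](onClose: T => Unit) {
  private val ref = new AtomicReference[T]()

  // Publish the resource; pairs with getAcquire below (release/acquire ordering).
  def register(resource: T): Unit = ref.setRelease(resource)

  def isRegistered: Boolean = ref.getAcquire() != null

  // getAndSet hands the old value to exactly one caller, so concurrent or
  // repeated cleanUp() calls invoke onClose at most once per registration.
  def cleanUp(): Unit = {
    val resource = ref.getAndSet(null.asInstanceOf[T])
    if (resource != null) onClose(resource)
  }
}
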
@@ -17,11 +17,8 @@

package org.apache.spark.sql.connect.service

-import java.util.concurrent.Executors
-import java.util.concurrent.ScheduledExecutorService
-import java.util.concurrent.TimeUnit
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
-import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
@@ -61,36 +58,34 @@ private[connect] class SparkConnectStreamingQueryCache(
      sessionHolder: SessionHolder,
      query: StreamingQuery,
      tags: Set[String],
-      operationId: String): Unit = queryCacheLock.synchronized {
-    taggedQueriesLock.synchronized {
+      operationId: String): Unit = {
    val value = QueryCacheValue(
      userId = sessionHolder.userId,
      sessionId = sessionHolder.sessionId,
      session = sessionHolder.session,
      query = query,
      operationId = operationId,
      expiresAtMs = None)

    val queryKey = QueryCacheKey(query.id.toString, query.runId.toString)
-    tags.foreach { tag =>
-      taggedQueries
-        .getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey])
-        .addOne(queryKey)
-    }
+    tags.foreach { tag => addTaggedQuery(tag, queryKey) }

-    queryCache.put(queryKey, value) match {
-      case Some(existing) => // Query is being replace. Not really expected.
+    queryCache.compute(
+      queryKey,
+      (key, existing) => {
+        if (existing != null) { // The query is being replaced: allowed, though not expected.
          logWarning(log"Replacing existing query in the cache (unexpected). " +
            log"Query Id: ${MDC(QUERY_ID, query.id)}.Existing value ${MDC(OLD_VALUE, existing)}, " +
            log"new value ${MDC(NEW_VALUE, value)}.")
-      case None =>
+        } else {
          logInfo(
            log"Adding new query to the cache. Query Id ${MDC(QUERY_ID, query.id)}, " +
              log"value ${MDC(QUERY_CACHE_VALUE, value)}.")
        }
+        value
+      })
-    }

-      schedulePeriodicChecks() // Starts the scheduler thread if it hasn't started.
-    }
+    schedulePeriodicChecks() // Start the scheduler thread if it has not been started.
  }

  /**
@@ -104,44 +99,35 @@ private[connect] class SparkConnectStreamingQueryCache(
      runId: String,
      tags: Set[String],
      session: SparkSession): Option[QueryCacheValue] = {
-    taggedQueriesLock.synchronized {
-      val key = QueryCacheKey(queryId, runId)
+    val queryKey = QueryCacheKey(queryId, runId)
    val result = getCachedQuery(QueryCacheKey(queryId, runId), session)
-      tags.foreach { tag =>
-        taggedQueries.getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]).addOne(key)
-      }
+    tags.foreach { tag => addTaggedQuery(tag, queryKey) }
    result
-    }
  }

  /**
   * Similar with [[getCachedQuery]] but it gets queries tagged previously.
   */
  def getTaggedQuery(tag: String, session: SparkSession): Seq[QueryCacheValue] = {
-    taggedQueriesLock.synchronized {
-      taggedQueries
-        .get(tag)
-        .map { k =>
-          k.flatMap(getCachedQuery(_, session)).toSeq
-        }
-        .getOrElse(Seq.empty[QueryCacheValue])
-    }
+    val queryKeySet = Option(taggedQueries.get(tag))
+    queryKeySet
+      .map(_.flatMap(k => getCachedQuery(k, session)))
+      .getOrElse(Seq.empty[QueryCacheValue])
  }

  private def getCachedQuery(
      key: QueryCacheKey,
      session: SparkSession): Option[QueryCacheValue] = {
-    queryCacheLock.synchronized {
-      queryCache.get(key).flatMap { v =>
+    val value = Option(queryCache.get(key))
+    value.flatMap { v =>
      if (v.session == session) {
        v.expiresAtMs.foreach { _ =>
          // Extend the expiry time as the client is accessing it.
          val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis
          queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs)))
        }
        Some(v)
      } else None // Should be rare, may be client is trying access from a different session.
    }
-    }
  }
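
Note on the idioms above: ConcurrentHashMap.compute runs its remapping function atomically per key, so the insert-or-replace decision in registerNewStreamingQuery (and its log line) cannot race with another writer the way a bare put-then-check could; and since the Java map returns null rather than None, reads are wrapped as Option(queryCache.get(key)). A small self-contained sketch of both idioms; the map and values here are made up for illustration:

import java.util.concurrent.ConcurrentHashMap

val cache = new ConcurrentHashMap[String, String]()

// Atomic insert-or-replace: the function sees the current mapping (or null),
// and whatever it returns becomes the new mapping for the key.
def upsert(key: String, value: String): Unit =
  cache.compute(
    key,
    (_, existing) => {
      if (existing != null) println(s"replacing $existing for $key")
      else println(s"adding $key")
      value
    })

// Bridge the Java map's null into Option, as getCachedQuery does above.
def lookup(key: String): Option[String] = Option(cache.get(key))
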

@@ -154,7 +140,7 @@ private[connect] class SparkConnectStreamingQueryCache(
      sessionHolder: SessionHolder,
      blocking: Boolean = true): Seq[String] = {
    val operationIds = new mutable.ArrayBuffer[String]()
-    for ((k, v) <- queryCache) {
+    queryCache.forEach((k, v) => {
      if (v.userId.equals(sessionHolder.userId) && v.sessionId.equals(sessionHolder.sessionId)) {
        if (v.query.isActive && Option(v.session.streams.get(k.queryId)).nonEmpty) {
          logInfo(
@@ -178,29 +164,27 @@ private[connect] class SparkConnectStreamingQueryCache(
          }
        }
      }
-    }
+    })
    operationIds.toSeq
  }

  // Visible for testing
  private[service] def getCachedValue(queryId: String, runId: String): Option[QueryCacheValue] =
-    queryCache.get(QueryCacheKey(queryId, runId))
+    Option(queryCache.get(QueryCacheKey(queryId, runId)))

  // Visible for testing.
-  private[service] def shutdown(): Unit = queryCacheLock.synchronized {
+  private[service] def shutdown(): Unit = {
    val executor = scheduledExecutor.getAndSet(null)
    if (executor != null) {
      ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
    }
  }

-  @GuardedBy("queryCacheLock")
-  private val queryCache = new mutable.HashMap[QueryCacheKey, QueryCacheValue]
-  private val queryCacheLock = new Object
+  private val queryCache: ConcurrentMap[QueryCacheKey, QueryCacheValue] =
+    new ConcurrentHashMap[QueryCacheKey, QueryCacheValue]

-  @GuardedBy("queryCacheLock")
-  private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]]
-  private val taggedQueriesLock = new Object
+  private[service] val taggedQueries: ConcurrentMap[String, QueryCacheKeySet] =
+    new ConcurrentHashMap[String, QueryCacheKeySet]

  private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
    new AtomicReference[ScheduledExecutorService]()
@@ -228,62 +212,109 @@ private[connect] class SparkConnectStreamingQueryCache(
    }
  }

+  private def addTaggedQuery(tag: String, queryKey: QueryCacheKey): Unit = {
+    taggedQueries.compute(
+      tag,
+      (k, v) => {
+        if (v == null || !v.addKey(queryKey)) {
+          // Create a new QueryCacheKeySet if the entry is absent or being removed.
+          var keys = mutable.HashSet.empty[QueryCacheKey]
+          keys.add(queryKey)
+          new QueryCacheKeySet(keys = keys)
+        } else {
+          v
+        }
+      })
+  }

  /**
   * Periodic maintenance task to do the following:
   *   - Update status of query if it is inactive. Sets an expiry time for such queries
   *   - Drop expired queries from the cache.
   */
-  private def periodicMaintenance(): Unit = taggedQueriesLock.synchronized {
-    queryCacheLock.synchronized {
-      val nowMs = clock.getTimeMillis()
-      for ((k, v) <- queryCache) {
+  private def periodicMaintenance(): Unit = {
+    val nowMs = clock.getTimeMillis()
+
+    queryCache.forEach((k, v) => {
      val id = k.queryId
      val runId = k.runId
      v.expiresAtMs match {
        case Some(ts) if nowMs >= ts => // Expired. Drop references.
          logInfo(
            log"Removing references for id: ${MDC(QUERY_ID, id)} " +
              log"runId: ${MDC(QUERY_RUN_ID, runId)} in " +
              log"session ${MDC(SESSION_ID, v.sessionId)} after expiry period")
          queryCache.remove(k)

        case Some(_) => // Inactive query waiting for expiration. Do nothing.
          logInfo(
            log"Waiting for the expiration for id: ${MDC(QUERY_ID, id)} " +
              log"runId: ${MDC(QUERY_RUN_ID, runId)} in " +
              log"session ${MDC(SESSION_ID, v.sessionId)}")

        case None => // Active query, check if it is stopped. Enable timeout if it is stopped.
          val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty
          if (!isActive) {
            logInfo(
              log"Marking query id: ${MDC(QUERY_ID, id)} " +
                log"runId: ${MDC(QUERY_RUN_ID, runId)} in " +
                log"session ${MDC(SESSION_ID, v.sessionId)} inactive.")
            val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis
            queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs)))
            // To consider: Clean up any runner registered for this query with the session holder
            // for this session. Useful in case listener events are delayed (such delays are
            // seen in practice, especially when users have heavy processing inside listeners).
            // Currently such workers would be cleaned up when the connect session expires.
          }
      }
-      }
-    }
+    })

-    taggedQueries.toArray.foreach { case (key, value) =>
-      value.zipWithIndex.toArray.foreach { case (queryKey, i) =>
-        if (queryCache.contains(queryKey)) {
-          value.remove(i)
-        }
-      }
-      if (value.isEmpty) {
-        taggedQueries.remove(key)
-      }
-    }
+    // Removes any tagged queries that do not correspond to cached queries.
+    taggedQueries.forEach((key, value) => {
+      if (value.filter(k => queryCache.containsKey(k))) {
+        taggedQueries.remove(key, value)
+      }
+    })
  }

Review thread on the removed condition "if (queryCache.contains(queryKey))":

  Author: To whom it may concern - I'm pretty sure that this condition was wrong, so fixed it.

  Author: @HyukjinKwon can you please check if the condition was right? As far as I understood, the key was supposed to be removed if no associated entry was found in queryCache. Thanks!

Review thread on the added line "taggedQueries.remove(key, value)":

  Member: Hm, it has to remove the element within value instead of the whole entry IIRC.

  Author: You're right, and the code actually does that, though it's not quite obvious; value.filter returns true only when value is empty after the filtering, so semantically the piece of code means:

    1. Remove elements that do not satisfy queryCache.containsKey(k) from value.
    2. If value is empty after filtering those elements, value will be removed from taggedQueries.

  Member: ah okie, then lgtm

+  case class QueryCacheKeySet(keys: mutable.HashSet[QueryCacheKey]) {
+
+    /** Tries to add the key if the set is not empty, otherwise returns false. */
+    def addKey(key: QueryCacheKey): Boolean = {
+      keys.synchronized {
+        if (keys.isEmpty) {
+          // The entry is about to be removed.
+          return false
+        }
+        keys.add(key)
+        true
+      }
+    }
+
+    /** Removes the key and returns true if the set is empty. */
+    def removeKey(key: QueryCacheKey): Boolean = {
+      keys.synchronized {
+        if (keys.remove(key)) {
+          return keys.isEmpty
+        }
+        false
+      }
+    }
+
+    /** Removes entries that do not satisfy the predicate. */
+    def filter(pred: QueryCacheKey => Boolean): Boolean = {
+      keys.synchronized {
+        keys.filterInPlace(k => pred(k))
+        keys.isEmpty
+      }
+    }
+
+    /** Iterates over entries, apply the function individually, and then flatten the result. */
+    def flatMap[T](function: QueryCacheKey => Option[T]): Seq[T] = {
+      keys.synchronized {
+        keys.flatMap(k => function(k)).toSeq
+      }
+    }
+  }
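
Note on how QueryCacheKeySet cooperates with addTaggedQuery and periodicMaintenance: once filter (or removeKey) empties a set, that entry is considered dead, so addKey refuses to revive it and the writer's compute installs a fresh set instead; taggedQueries.remove(key, value) is the two-argument, conditional remove, so it leaves a freshly installed replacement untouched. A hypothetical miniature of that race, with String keys standing in for QueryCacheKey (TagSet, add, and prune are illustrative names, not from this PR):

import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable

final case class TagSet(keys: mutable.HashSet[String]) {
  // Refuse to add to an emptied set: its map entry is about to be removed.
  def addKey(k: String): Boolean = keys.synchronized {
    if (keys.isEmpty) false else { keys.add(k); true }
  }
  // Drop keys failing the predicate; report whether the set is now empty.
  def filter(pred: String => Boolean): Boolean = keys.synchronized {
    keys.filterInPlace(pred); keys.isEmpty
  }
}

val tags = new ConcurrentHashMap[String, TagSet]()

def add(tag: String, key: String): Unit =
  tags.compute(
    tag,
    (_, v) => if (v == null || !v.addKey(key)) TagSet(mutable.HashSet(key)) else v)

def prune(live: String => Boolean): Unit =
  tags.forEach((tag, set) => {
    // Conditional remove: a replacement set installed concurrently under the
    // same tag does not match `set` and is left alone.
    if (set.filter(live)) tags.remove(tag, set)
  })

Here prune mirrors the cleanup loop in periodicMaintenance: the set is emptied first, and only then is the map entry conditionally removed, which matches the reading confirmed in the review thread above.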