-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-49876][CONNECT] Get rid of global locks from Spark Connect Service #48350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,11 +17,8 @@ | |
|
|
||
| package org.apache.spark.sql.connect.service | ||
|
|
||
| import java.util.concurrent.Executors | ||
| import java.util.concurrent.ScheduledExecutorService | ||
| import java.util.concurrent.TimeUnit | ||
| import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} | ||
| import java.util.concurrent.atomic.AtomicReference | ||
| import javax.annotation.concurrent.GuardedBy | ||
|
|
||
| import scala.collection.mutable | ||
| import scala.concurrent.{ExecutionContext, Future} | ||
|
|
@@ -61,36 +58,34 @@ private[connect] class SparkConnectStreamingQueryCache( | |
| sessionHolder: SessionHolder, | ||
| query: StreamingQuery, | ||
| tags: Set[String], | ||
| operationId: String): Unit = queryCacheLock.synchronized { | ||
| taggedQueriesLock.synchronized { | ||
| val value = QueryCacheValue( | ||
| userId = sessionHolder.userId, | ||
| sessionId = sessionHolder.sessionId, | ||
| session = sessionHolder.session, | ||
| query = query, | ||
| operationId = operationId, | ||
| expiresAtMs = None) | ||
|
|
||
| val queryKey = QueryCacheKey(query.id.toString, query.runId.toString) | ||
| tags.foreach { tag => | ||
| taggedQueries | ||
| .getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]) | ||
| .addOne(queryKey) | ||
| } | ||
|
|
||
| queryCache.put(queryKey, value) match { | ||
| case Some(existing) => // Query is being replaced. Not really expected. |
| operationId: String): Unit = { | ||
| val value = QueryCacheValue( | ||
| userId = sessionHolder.userId, | ||
| sessionId = sessionHolder.sessionId, | ||
| session = sessionHolder.session, | ||
| query = query, | ||
| operationId = operationId, | ||
| expiresAtMs = None) | ||
|
|
||
| val queryKey = QueryCacheKey(query.id.toString, query.runId.toString) | ||
| tags.foreach { tag => addTaggedQuery(tag, queryKey) } | ||
|
|
||
| queryCache.compute( | ||
| queryKey, | ||
| (key, existing) => { | ||
| if (existing != null) { // The query is being replaced: allowed, though not expected. | ||
| logWarning(log"Replacing existing query in the cache (unexpected). " + | ||
| log"Query Id: ${MDC(QUERY_ID, query.id)}. Existing value ${MDC(OLD_VALUE, existing)}, " + | ||
| log"new value ${MDC(NEW_VALUE, value)}.") | ||
| case None => | ||
| } else { | ||
| logInfo( | ||
| log"Adding new query to the cache. Query Id ${MDC(QUERY_ID, query.id)}, " + | ||
| log"value ${MDC(QUERY_CACHE_VALUE, value)}.") | ||
| } | ||
| } | ||
| value | ||
| }) | ||
|
|
||
| schedulePeriodicChecks() // Starts the scheduler thread if it hasn't started. | ||
| } | ||
| schedulePeriodicChecks() // Start the scheduler thread if it has not been started. | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -104,44 +99,35 @@ private[connect] class SparkConnectStreamingQueryCache( | |
| runId: String, | ||
| tags: Set[String], | ||
| session: SparkSession): Option[QueryCacheValue] = { | ||
| taggedQueriesLock.synchronized { | ||
| val key = QueryCacheKey(queryId, runId) | ||
| val result = getCachedQuery(QueryCacheKey(queryId, runId), session) | ||
| tags.foreach { tag => | ||
| taggedQueries.getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]).addOne(key) | ||
| } | ||
| result | ||
| } | ||
| val queryKey = QueryCacheKey(queryId, runId) | ||
| val result = getCachedQuery(QueryCacheKey(queryId, runId), session) | ||
| tags.foreach { tag => addTaggedQuery(tag, queryKey) } | ||
| result | ||
| } | ||
|
|
||
| /** | ||
| * Similar to [[getCachedQuery]], but it retrieves queries tagged previously. | ||
| */ | ||
| def getTaggedQuery(tag: String, session: SparkSession): Seq[QueryCacheValue] = { | ||
| taggedQueriesLock.synchronized { | ||
| taggedQueries | ||
| .get(tag) | ||
| .map { k => | ||
| k.flatMap(getCachedQuery(_, session)).toSeq | ||
| } | ||
| .getOrElse(Seq.empty[QueryCacheValue]) | ||
| } | ||
| val queryKeySet = Option(taggedQueries.get(tag)) | ||
| queryKeySet | ||
| .map(_.flatMap(k => getCachedQuery(k, session))) | ||
| .getOrElse(Seq.empty[QueryCacheValue]) | ||
| } | ||
|
|
||
| private def getCachedQuery( | ||
| key: QueryCacheKey, | ||
| session: SparkSession): Option[QueryCacheValue] = { | ||
| queryCacheLock.synchronized { | ||
| queryCache.get(key).flatMap { v => | ||
| if (v.session == session) { | ||
| v.expiresAtMs.foreach { _ => | ||
| // Extend the expiry time as the client is accessing it. | ||
| val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis | ||
| queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs))) | ||
| } | ||
| Some(v) | ||
| } else None // Should be rare; maybe the client is trying to access it from a different session. | ||
| } | ||
| val value = Option(queryCache.get(key)) | ||
| value.flatMap { v => | ||
| if (v.session == session) { | ||
| v.expiresAtMs.foreach { _ => | ||
| // Extend the expiry time as the client is accessing it. | ||
| val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis | ||
| queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs))) | ||
| } | ||
| Some(v) | ||
| } else None // Should be rare; maybe the client is trying to access it from a different session. | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -154,7 +140,7 @@ private[connect] class SparkConnectStreamingQueryCache( | |
| sessionHolder: SessionHolder, | ||
| blocking: Boolean = true): Seq[String] = { | ||
| val operationIds = new mutable.ArrayBuffer[String]() | ||
| for ((k, v) <- queryCache) { | ||
| queryCache.forEach((k, v) => { | ||
| if (v.userId.equals(sessionHolder.userId) && v.sessionId.equals(sessionHolder.sessionId)) { | ||
| if (v.query.isActive && Option(v.session.streams.get(k.queryId)).nonEmpty) { | ||
| logInfo( | ||
|
|
@@ -178,29 +164,27 @@ private[connect] class SparkConnectStreamingQueryCache( | |
| } | ||
| } | ||
| } | ||
| } | ||
| }) | ||
| operationIds.toSeq | ||
| } | ||
|
|
||
| // Visible for testing | ||
| private[service] def getCachedValue(queryId: String, runId: String): Option[QueryCacheValue] = | ||
| queryCache.get(QueryCacheKey(queryId, runId)) | ||
| Option(queryCache.get(QueryCacheKey(queryId, runId))) | ||
|
|
||
| // Visible for testing. | ||
| private[service] def shutdown(): Unit = queryCacheLock.synchronized { | ||
| private[service] def shutdown(): Unit = { | ||
| val executor = scheduledExecutor.getAndSet(null) | ||
| if (executor != null) { | ||
| ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) | ||
| } | ||
| } | ||
|
|
||
| @GuardedBy("queryCacheLock") | ||
| private val queryCache = new mutable.HashMap[QueryCacheKey, QueryCacheValue] | ||
| private val queryCacheLock = new Object | ||
| private val queryCache: ConcurrentMap[QueryCacheKey, QueryCacheValue] = | ||
| new ConcurrentHashMap[QueryCacheKey, QueryCacheValue] | ||
|
|
||
| @GuardedBy("queryCacheLock") | ||
| private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]] | ||
| private val taggedQueriesLock = new Object | ||
| private[service] val taggedQueries: ConcurrentMap[String, QueryCacheKeySet] = | ||
| new ConcurrentHashMap[String, QueryCacheKeySet] | ||
|
|
||
| private var scheduledExecutor: AtomicReference[ScheduledExecutorService] = | ||
| new AtomicReference[ScheduledExecutorService]() | ||
|
|
@@ -228,62 +212,109 @@ private[connect] class SparkConnectStreamingQueryCache( | |
| } | ||
| } | ||
|
|
||
| private def addTaggedQuery(tag: String, queryKey: QueryCacheKey): Unit = { | ||
| taggedQueries.compute( | ||
| tag, | ||
| (k, v) => { | ||
| if (v == null || !v.addKey(queryKey)) { | ||
| // Create a new QueryCacheKeySet if the entry is absent or being removed. | ||
| var keys = mutable.HashSet.empty[QueryCacheKey] | ||
| keys.add(queryKey) | ||
| new QueryCacheKeySet(keys = keys) | ||
| } else { | ||
| v | ||
| } | ||
| }) | ||
| } | ||
|
|
||
| /** | ||
| * Periodic maintenance task to do the following: | ||
| * - Update status of query if it is inactive. Sets an expiry time for such queries | ||
| * - Drop expired queries from the cache. | ||
| */ | ||
| private def periodicMaintenance(): Unit = taggedQueriesLock.synchronized { | ||
| private def periodicMaintenance(): Unit = { | ||
| val nowMs = clock.getTimeMillis() | ||
|
|
||
| queryCacheLock.synchronized { | ||
| val nowMs = clock.getTimeMillis() | ||
| queryCache.forEach((k, v) => { | ||
| val id = k.queryId | ||
| val runId = k.runId | ||
| v.expiresAtMs match { | ||
|
|
||
| for ((k, v) <- queryCache) { | ||
| val id = k.queryId | ||
| val runId = k.runId | ||
| v.expiresAtMs match { | ||
| case Some(ts) if nowMs >= ts => // Expired. Drop references. | ||
| logInfo( | ||
| log"Removing references for id: ${MDC(QUERY_ID, id)} " + | ||
| log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + | ||
| log"session ${MDC(SESSION_ID, v.sessionId)} after expiry period") | ||
| queryCache.remove(k) | ||
|
|
||
| case Some(ts) if nowMs >= ts => // Expired. Drop references. | ||
| logInfo( | ||
| log"Removing references for id: ${MDC(QUERY_ID, id)} " + | ||
| log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + | ||
| log"session ${MDC(SESSION_ID, v.sessionId)} after expiry period") | ||
| queryCache.remove(k) | ||
| case Some(_) => // Inactive query waiting for expiration. Do nothing. | ||
| logInfo( | ||
| log"Waiting for the expiration for id: ${MDC(QUERY_ID, id)} " + | ||
| log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + | ||
| log"session ${MDC(SESSION_ID, v.sessionId)}") | ||
|
|
||
| case None => // Active query, check if it is stopped. Enable timeout if it is stopped. | ||
| val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty | ||
|
|
||
| case Some(_) => // Inactive query waiting for expiration. Do nothing. | ||
| if (!isActive) { | ||
| logInfo( | ||
| log"Waiting for the expiration for id: ${MDC(QUERY_ID, id)} " + | ||
| log"Marking query id: ${MDC(QUERY_ID, id)} " + | ||
| log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + | ||
| log"session ${MDC(SESSION_ID, v.sessionId)}") | ||
|
|
||
| case None => // Active query, check if it is stopped. Enable timeout if it is stopped. | ||
| val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty | ||
|
|
||
| if (!isActive) { | ||
| logInfo( | ||
| log"Marking query id: ${MDC(QUERY_ID, id)} " + | ||
| log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + | ||
| log"session ${MDC(SESSION_ID, v.sessionId)} inactive.") | ||
| val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis | ||
| queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs))) | ||
| // To consider: Clean up any runner registered for this query with the session holder | ||
| // for this session. Useful in case listener events are delayed (such delays are | ||
| // seen in practice, especially when users have heavy processing inside listeners). | ||
| // Currently such workers would be cleaned up when the connect session expires. | ||
| } | ||
| } | ||
| log"session ${MDC(SESSION_ID, v.sessionId)} inactive.") | ||
| val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis | ||
| queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs))) | ||
| // To consider: Clean up any runner registered for this query with the session holder | ||
| // for this session. Useful in case listener events are delayed (such delays are | ||
| // seen in practice, especially when users have heavy processing inside listeners). | ||
| // Currently such workers would be cleaned up when the connect session expires. | ||
| } | ||
| } | ||
| }) | ||
|
|
||
| taggedQueries.toArray.foreach { case (key, value) => | ||
| value.zipWithIndex.toArray.foreach { case (queryKey, i) => | ||
| if (queryCache.contains(queryKey)) { | ||
| value.remove(i) | ||
| } | ||
| // Removes any tagged queries that do not correspond to cached queries. | ||
| taggedQueries.forEach((key, value) => { | ||
| if (value.filter(k => queryCache.containsKey(k))) { | ||
| taggedQueries.remove(key, value) | ||
|
||
| } | ||
| }) | ||
| } | ||
|
|
||
| case class QueryCacheKeySet(keys: mutable.HashSet[QueryCacheKey]) { | ||
|
|
||
| /** Tries to add the key if the set is not empty, otherwise returns false. */ | ||
| def addKey(key: QueryCacheKey): Boolean = { | ||
| keys.synchronized { | ||
| if (keys.isEmpty) { | ||
| // The entry is about to be removed. | ||
| return false | ||
| } | ||
| keys.add(key) | ||
| true | ||
| } | ||
| } | ||
|
|
||
| if (value.isEmpty) { | ||
| taggedQueries.remove(key) | ||
| /** Removes the key and returns true if the set is empty. */ | ||
| def removeKey(key: QueryCacheKey): Boolean = { | ||
| keys.synchronized { | ||
| if (keys.remove(key)) { | ||
| return keys.isEmpty | ||
| } | ||
| false | ||
| } | ||
| } | ||
|
|
||
| /** Removes entries that do not satisfy the predicate. */ | ||
| def filter(pred: QueryCacheKey => Boolean): Boolean = { | ||
| keys.synchronized { | ||
| keys.filterInPlace(k => pred(k)) | ||
| keys.isEmpty | ||
| } | ||
| } | ||
|
|
||
| /** Iterates over entries, apply the function individually, and then flatten the result. */ | ||
| def flatMap[T](function: QueryCacheKey => Option[T]): Seq[T] = { | ||
| keys.synchronized { | ||
| keys.flatMap(k => function(k)).toSeq | ||
| } | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To whom it may concern - I'm pretty sure that this condition was wrong, so fixed it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@HyukjinKwon can you please check whether the condition was right? As far as I understood, the key was supposed to be removed if no associated entry was found in queryCache. Thanks!