From 4e14ba2551b4b8466bbc4ae97d8555797e37329a Mon Sep 17 00:00:00 2001
From: Changgyoo Park
Date: Fri, 4 Oct 2024 15:21:45 +0200
Subject: [PATCH] Optimize SparkConnectStreamingQueryCache and
 ServerSideListenerHolder with lock-free concurrent data structures

Replace the coarse object monitors in ServerSideListenerHolder and
SparkConnectStreamingQueryCache with an AtomicReference and ConcurrentHashMaps,
and guard each tag's key set with its own monitor (QueryCacheKeySet), so that
listener registration, query-cache lookups, and periodic maintenance no longer
serialize on global locks.
---
 .../SparkConnectListenerBusListener.scala     |  22 +-
 .../SparkConnectStreamingQueryCache.scala     | 239 ++++++++++--------
 ...SparkConnectListenerBusListenerSuite.scala |   3 +-
 ...SparkConnectStreamingQueryCacheSuite.scala |  14 +-
 4 files changed, 160 insertions(+), 118 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
index 7a0c067ab430..445f40d25edc 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.service
 
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
+import java.util.concurrent.atomic.AtomicReference
 
 import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
@@ -41,7 +42,8 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
   // The server side listener that is responsible to stream streaming query events back to client.
   // There is only one listener per sessionHolder, but each listener is responsible for all events
   // of all streaming queries in the SparkSession.
-  var streamingQueryServerSideListener: Option[SparkConnectListenerBusListener] = None
+  val streamingQueryServerSideListener: AtomicReference[SparkConnectListenerBusListener] =
+    new AtomicReference()
   // The cache for QueryStartedEvent, key is query runId and value is the actual QueryStartedEvent.
   // Events for corresponding query will be sent back to client with
   // the WriteStreamOperationStart response, so that the client can handle the event before
@@ -50,10 +52,8 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
   val streamingQueryStartedEventCache
     : ConcurrentMap[String, StreamingQueryListener.QueryStartedEvent] = new ConcurrentHashMap()
 
-  val lock = new Object()
-
-  def isServerSideListenerRegistered: Boolean = lock.synchronized {
-    streamingQueryServerSideListener.isDefined
+  def isServerSideListenerRegistered: Boolean = {
+    streamingQueryServerSideListener.getAcquire() != null
   }
 
   /**
@@ -65,10 +65,10 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
    * @param responseObserver
    *   the responseObserver created from the first long running executeThread.
    */
-  def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = lock.synchronized {
+  def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val serverListener = new SparkConnectListenerBusListener(this, responseObserver)
     sessionHolder.session.streams.addListener(serverListener)
-    streamingQueryServerSideListener = Some(serverListener)
+    streamingQueryServerSideListener.setRelease(serverListener)
   }
 
   /**
@@ -77,13 +77,13 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
    * exception. It removes the listener from the session, clears the cache. Also it sends back the
    * final ResultComplete response.
   */
-  def cleanUp(): Unit = lock.synchronized {
-    streamingQueryServerSideListener.foreach { listener =>
+  def cleanUp(): Unit = {
+    val listener = streamingQueryServerSideListener.getAndSet(null)
+    if (listener != null) {
       sessionHolder.session.streams.removeListener(listener)
       listener.sendResultComplete()
+      streamingQueryStartedEventCache.clear()
     }
-    streamingQueryStartedEventCache.clear()
-    streamingQueryServerSideListener = None
   }
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
index 48492bac6234..3da2548b456e 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
@@ -17,11 +17,8 @@
 
 package org.apache.spark.sql.connect.service
 
-import java.util.concurrent.Executors
-import java.util.concurrent.ScheduledExecutorService
-import java.util.concurrent.TimeUnit
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
 import java.util.concurrent.atomic.AtomicReference
-import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
 import scala.concurrent.{ExecutionContext, Future}
@@ -61,36 +58,34 @@ private[connect] class SparkConnectStreamingQueryCache(
       sessionHolder: SessionHolder,
       query: StreamingQuery,
       tags: Set[String],
-      operationId: String): Unit = queryCacheLock.synchronized {
-    taggedQueriesLock.synchronized {
-      val value = QueryCacheValue(
-        userId = sessionHolder.userId,
-        sessionId = sessionHolder.sessionId,
-        session = sessionHolder.session,
-        query = query,
-        operationId = operationId,
-        expiresAtMs = None)
-
-      val queryKey = QueryCacheKey(query.id.toString, query.runId.toString)
-      tags.foreach { tag =>
-        taggedQueries
-          .getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey])
-          .addOne(queryKey)
-      }
-
-      queryCache.put(queryKey, value) match {
-        case Some(existing) => // Query is being replace. Not really expected.
+      operationId: String): Unit = {
+    val value = QueryCacheValue(
+      userId = sessionHolder.userId,
+      sessionId = sessionHolder.sessionId,
+      session = sessionHolder.session,
+      query = query,
+      operationId = operationId,
+      expiresAtMs = None)
+
+    val queryKey = QueryCacheKey(query.id.toString, query.runId.toString)
+    tags.foreach { tag => addTaggedQuery(tag, queryKey) }
+
+    queryCache.compute(
+      queryKey,
+      (key, existing) => {
+        if (existing != null) { // The query is being replaced: allowed, though not expected.
           logWarning(log"Replacing existing query in the cache (unexpected). " +
             log"Query Id: ${MDC(QUERY_ID, query.id)}.Existing value ${MDC(OLD_VALUE, existing)}, " +
            log"new value ${MDC(NEW_VALUE, value)}.")
-        case None =>
+        } else {
           logInfo(
             log"Adding new query to the cache. Query Id ${MDC(QUERY_ID, query.id)}, " +
               log"value ${MDC(QUERY_CACHE_VALUE, value)}.")
-      }
+        }
+        value
+      })
 
-    schedulePeriodicChecks() // Starts the scheduler thread if it hasn't started.
-  }
+    schedulePeriodicChecks() // Start the scheduler thread if it has not been started.
   }
 
   /**
@@ -104,44 +99,35 @@ private[connect] class SparkConnectStreamingQueryCache(
       runId: String,
       tags: Set[String],
       session: SparkSession): Option[QueryCacheValue] = {
-    taggedQueriesLock.synchronized {
-      val key = QueryCacheKey(queryId, runId)
-      val result = getCachedQuery(QueryCacheKey(queryId, runId), session)
-      tags.foreach { tag =>
-        taggedQueries.getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]).addOne(key)
-      }
-      result
-    }
+    val queryKey = QueryCacheKey(queryId, runId)
+    val result = getCachedQuery(queryKey, session)
+    tags.foreach { tag => addTaggedQuery(tag, queryKey) }
+    result
   }
 
   /**
    * Similar with [[getCachedQuery]] but it gets queries tagged previously.
    */
   def getTaggedQuery(tag: String, session: SparkSession): Seq[QueryCacheValue] = {
-    taggedQueriesLock.synchronized {
-      taggedQueries
-        .get(tag)
-        .map { k =>
-          k.flatMap(getCachedQuery(_, session)).toSeq
-        }
-        .getOrElse(Seq.empty[QueryCacheValue])
-    }
+    val queryKeySet = Option(taggedQueries.get(tag))
+    queryKeySet
+      .map(_.flatMap(k => getCachedQuery(k, session)))
+      .getOrElse(Seq.empty[QueryCacheValue])
   }
 
   private def getCachedQuery(
       key: QueryCacheKey,
       session: SparkSession): Option[QueryCacheValue] = {
-    queryCacheLock.synchronized {
-      queryCache.get(key).flatMap { v =>
-        if (v.session == session) {
-          v.expiresAtMs.foreach { _ =>
-            // Extend the expiry time as the client is accessing it.
-            val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis
-            queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs)))
-          }
-          Some(v)
-        } else None // Should be rare, may be client is trying access from a different session.
-      }
+    val value = Option(queryCache.get(key))
+    value.flatMap { v =>
+      if (v.session == session) {
+        v.expiresAtMs.foreach { _ =>
+          // Extend the expiry time as the client is accessing it.
+          val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis
+          queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs)))
+        }
+        Some(v)
+      } else None // Should be rare; maybe the client is trying to access it from a different session.
     }
   }
 
@@ -154,7 +140,7 @@ private[connect] class SparkConnectStreamingQueryCache(
       sessionHolder: SessionHolder,
       blocking: Boolean = true): Seq[String] = {
     val operationIds = new mutable.ArrayBuffer[String]()
-    for ((k, v) <- queryCache) {
+    queryCache.forEach((k, v) => {
       if (v.userId.equals(sessionHolder.userId) && v.sessionId.equals(sessionHolder.sessionId)) {
         if (v.query.isActive && Option(v.session.streams.get(k.queryId)).nonEmpty) {
           logInfo(
@@ -178,29 +164,27 @@ private[connect] class SparkConnectStreamingQueryCache(
           }
         }
       }
-    }
+    })
     operationIds.toSeq
   }
 
   // Visible for testing
   private[service] def getCachedValue(queryId: String, runId: String): Option[QueryCacheValue] =
-    queryCache.get(QueryCacheKey(queryId, runId))
+    Option(queryCache.get(QueryCacheKey(queryId, runId)))
 
   // Visible for testing.
-  private[service] def shutdown(): Unit = queryCacheLock.synchronized {
+  private[service] def shutdown(): Unit = {
     val executor = scheduledExecutor.getAndSet(null)
     if (executor != null) {
       ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
     }
   }
 
-  @GuardedBy("queryCacheLock")
-  private val queryCache = new mutable.HashMap[QueryCacheKey, QueryCacheValue]
-  private val queryCacheLock = new Object
+  private val queryCache: ConcurrentMap[QueryCacheKey, QueryCacheValue] =
+    new ConcurrentHashMap[QueryCacheKey, QueryCacheValue]
 
-  @GuardedBy("queryCacheLock")
-  private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]]
-  private val taggedQueriesLock = new Object
+  private[service] val taggedQueries: ConcurrentMap[String, QueryCacheKeySet] =
+    new ConcurrentHashMap[String, QueryCacheKeySet]
 
   private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
     new AtomicReference[ScheduledExecutorService]()
@@ -228,62 +212,109 @@
     }
   }
 
+  private def addTaggedQuery(tag: String, queryKey: QueryCacheKey): Unit = {
+    taggedQueries.compute(
+      tag,
+      (k, v) => {
+        if (v == null || !v.addKey(queryKey)) {
+          // Create a new QueryCacheKeySet if the entry is absent or being removed.
+          val keys = mutable.HashSet.empty[QueryCacheKey]
+          keys.add(queryKey)
+          new QueryCacheKeySet(keys = keys)
+        } else {
+          v
+        }
+      })
+  }
+
   /**
    * Periodic maintenance task to do the following:
    *   - Update status of query if it is inactive. Sets an expiry time for such queries
    *   - Drop expired queries from the cache.
    */
-  private def periodicMaintenance(): Unit = taggedQueriesLock.synchronized {
+  private def periodicMaintenance(): Unit = {
+    val nowMs = clock.getTimeMillis()
 
-    queryCacheLock.synchronized {
-      val nowMs = clock.getTimeMillis()
+    queryCache.forEach((k, v) => {
+      val id = k.queryId
+      val runId = k.runId
+      v.expiresAtMs match {
 
-      for ((k, v) <- queryCache) {
-        val id = k.queryId
-        val runId = k.runId
-        v.expiresAtMs match {
+        case Some(ts) if nowMs >= ts => // Expired. Drop references.
+          logInfo(
+            log"Removing references for id: ${MDC(QUERY_ID, id)} " +
+              log"runId: ${MDC(QUERY_RUN_ID, runId)} in " +
+              log"session ${MDC(SESSION_ID, v.sessionId)} after expiry period")
+          queryCache.remove(k)
 
-          case Some(ts) if nowMs >= ts => // Expired. Drop references.
-            logInfo(
-              log"Removing references for id: ${MDC(QUERY_ID, id)} " +
-                log"runId: ${MDC(QUERY_RUN_ID, runId)} in " +
-                log"session ${MDC(SESSION_ID, v.sessionId)} after expiry period")
-            queryCache.remove(k)
+        case Some(_) => // Inactive query waiting for expiration. Do nothing.
+          logInfo(
+            log"Waiting for the expiration for id: ${MDC(QUERY_ID, id)} " +
+              log"runId: ${MDC(QUERY_RUN_ID, runId)} in " +
+              log"session ${MDC(SESSION_ID, v.sessionId)}")
+
+        case None => // Active query, check if it is stopped. Enable timeout if it is stopped.
+          val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty
 
-          case Some(_) => // Inactive query waiting for expiration. Do nothing.
+          if (!isActive) {
             logInfo(
-              log"Waiting for the expiration for id: ${MDC(QUERY_ID, id)} " +
+              log"Marking query id: ${MDC(QUERY_ID, id)} " +
                 log"runId: ${MDC(QUERY_RUN_ID, runId)} in " +
-                log"session ${MDC(SESSION_ID, v.sessionId)}")
-
-          case None => // Active query, check if it is stopped. Enable timeout if it is stopped.
-            val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty
-
-            if (!isActive) {
-              logInfo(
-                log"Marking query id: ${MDC(QUERY_ID, id)} " +
-                  log"runId: ${MDC(QUERY_RUN_ID, runId)} in " +
-                  log"session ${MDC(SESSION_ID, v.sessionId)} inactive.")
-              val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis
-              queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs)))
-              // To consider: Clean up any runner registered for this query with the session holder
-              // for this session. Useful in case listener events are delayed (such delays are
-              // seen in practice, especially when users have heavy processing inside listeners).
-              // Currently such workers would be cleaned up when the connect session expires.
-            }
-          }
+                log"session ${MDC(SESSION_ID, v.sessionId)} inactive.")
+            val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis
+            queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs)))
+            // To consider: Clean up any runner registered for this query with the session holder
+            // for this session. Useful in case listener events are delayed (such delays are
+            // seen in practice, especially when users have heavy processing inside listeners).
+            // Currently such workers would be cleaned up when the connect session expires.
+          }
       }
+    })
 
-      taggedQueries.toArray.foreach { case (key, value) =>
-        value.zipWithIndex.toArray.foreach { case (queryKey, i) =>
-          if (queryCache.contains(queryKey)) {
-            value.remove(i)
-          }
+    // Removes any tagged queries that no longer correspond to cached queries.
+    taggedQueries.forEach((key, value) => {
+      if (value.filter(k => queryCache.containsKey(k))) {
+        taggedQueries.remove(key, value)
+      }
+    })
+  }
+
+  case class QueryCacheKeySet(keys: mutable.HashSet[QueryCacheKey]) {
+
+    /** Adds the key and returns true, unless the set is empty because the entry is being removed. */
+    def addKey(key: QueryCacheKey): Boolean = {
+      keys.synchronized {
+        if (keys.isEmpty) {
+          // The entry is about to be removed.
+          return false
        }
+        keys.add(key)
+        true
+      }
+    }
 
-      if (value.isEmpty) {
-        taggedQueries.remove(key)
+    /** Removes the key; returns true if the key was present and the set became empty. */
+    def removeKey(key: QueryCacheKey): Boolean = {
+      keys.synchronized {
+        if (keys.remove(key)) {
+          return keys.isEmpty
        }
+        false
+      }
+    }
+
+    /** Retains only entries that satisfy the predicate; returns true if the set is left empty. */
+    def filter(pred: QueryCacheKey => Boolean): Boolean = {
+      keys.synchronized {
+        keys.filterInPlace(k => pred(k))
+        keys.isEmpty
+      }
+    }
+
+    /** Applies the function to each entry and flattens the results.
+     */
+    def flatMap[T](function: QueryCacheKey => Option[T]): Seq[T] = {
+      keys.synchronized {
+        keys.flatMap(k => function(k)).toSeq
+      }
+    }
+  }
 }
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala
index d856ffaabc31..2404dea21d91 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala
@@ -202,7 +202,8 @@ class SparkConnectListenerBusListenerSuite
     val listenerHolder = sessionHolder.streamingServersideListenerHolder
     eventually(timeout(5.seconds), interval(500.milliseconds)) {
       assert(
-        sessionHolder.streamingServersideListenerHolder.streamingQueryServerSideListener.isEmpty)
+        sessionHolder.streamingServersideListenerHolder.streamingQueryServerSideListener.get() ==
+          null)
       assert(spark.streams.listListeners().size === listenerCntBeforeThrow)
       assert(listenerHolder.streamingQueryStartedEventCache.isEmpty)
     }
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
index 512a0a80c4a9..729a995f4614 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala
@@ -48,6 +48,7 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSugar {
 
     val queryId = UUID.randomUUID().toString
     val runId = UUID.randomUUID().toString
+    val tag = "test_tag"
     val mockSession = mock[SparkSession]
     val mockQuery = mock[StreamingQuery]
     val mockStreamingQueryManager = mock[StreamingQueryManager]
@@ -67,13 +68,16 @@
 
     // Register the query.
-    sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery, Set.empty[String], "")
+    sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery, Set(tag), "")
 
     sessionMgr.getCachedValue(queryId, runId) match {
       case Some(v) =>
         assert(v.sessionId == sessionHolder.sessionId)
         assert(v.expiresAtMs.isEmpty, "No expiry time should be set for active query")
+        val taggedQueries = sessionMgr.getTaggedQuery(tag, mockSession)
+        assert(taggedQueries.contains(v))
+
       case None =>
         assert(false, "Query should be found")
     }
@@ -127,6 +131,9 @@
     assert(sessionMgr.getCachedValue(queryId, runId).map(_.query).contains(mockQuery))
     assert(
       sessionMgr.getCachedValue(queryId, restartedRunId).map(_.query).contains(restartedQuery))
+    eventually(timeout(1.minute)) {
+      assert(sessionMgr.taggedQueries.containsKey(tag))
+    }
 
     // Advance time by 1 minute and verify the first query is dropped from the cache.
     clock.advance(1.minute.toMillis)
@@ -144,8 +151,11 @@
     clock.advance(1.minute.toMillis)
     eventually(timeout(1.minute)) {
       assert(sessionMgr.getCachedValue(queryId, restartedRunId).isEmpty)
+      assert(sessionMgr.getTaggedQuery(tag, mockSession).isEmpty)
+    }
+    eventually(timeout(1.minute)) {
+      assert(!sessionMgr.taggedQueries.containsKey(tag))
     }
-    sessionMgr.shutdown()
   }
 }
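
Reviewer sketch (appended note, not part of the commit): cleanUp() now relies on
AtomicReference.getAndSet(null) to make teardown one-shot, since at most one
caller can observe the non-null listener even when cleanUp() races with itself.
A minimal standalone restatement of that idiom follows; OneShot and release are
illustrative names that do not appear in the patch.

  import java.util.concurrent.atomic.AtomicReference

  // One-shot teardown: only the caller that wins getAndSet(null) sees the
  // resource, so removeListener/sendResultComplete-style cleanup cannot run
  // twice even under concurrent calls to close().
  final class OneShot[T <: AnyRef](resource: T)(release: T => Unit) {
    private val ref = new AtomicReference[T](resource)

    def close(): Unit = {
      val r = ref.getAndSet(null.asInstanceOf[T])
      if (r != null) release(r)
    }
  }

Note that, as in cleanUp(), any state owned by the resource (here, the event
cache) should be cleared inside the winner's branch so it runs exactly once.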
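Reviewer sketch (appended note, not part of the commit): the subtlest part of
the cache change is the coordination between addTaggedQuery() and the tag
cleanup in periodicMaintenance(). The convention is that an empty
QueryCacheKeySet acts as a tombstone: addKey() refuses to mutate an empty set,
forcing the writer's compute() to install a fresh set, while the maintenance
pass uses the two-argument ConcurrentHashMap.remove(key, value) so it never
unlinks a set that a concurrent writer has just replaced. The standalone Scala
sketch below restates that protocol under illustrative names (TagIndex, KeySet,
prune, snapshot); none of these identifiers come from the patch.

  import java.util.concurrent.ConcurrentHashMap
  import scala.collection.mutable

  // Sketch of the tag-index protocol: an empty KeySet is a tombstone that
  // writers must never resurrect in place.
  final class TagIndex[K] {

    // Counterpart of QueryCacheKeySet: a mutable set guarded by its own
    // monitor. Instances are created non-empty, so observing an empty set can
    // only mean a maintenance pass drained it and is about to unlink it.
    final class KeySet(firstKey: K) {
      private val keys = mutable.HashSet[K](firstKey)

      /** Adds the key; returns false if this set is already a tombstone. */
      def addKey(key: K): Boolean = keys.synchronized {
        if (keys.isEmpty) false
        else {
          keys.add(key)
          true
        }
      }

      /** Drops keys that are no longer live; returns true if now a tombstone. */
      def retain(live: K => Boolean): Boolean = keys.synchronized {
        keys.filterInPlace(live)
        keys.isEmpty
      }

      /** A consistent snapshot for readers. */
      def snapshot: Seq[K] = keys.synchronized(keys.toSeq)
    }

    private val index = new ConcurrentHashMap[String, KeySet]()

    /** Adds `key` under `tag`; installs a fresh set when a tombstone is hit. */
    def add(tag: String, key: K): Unit = {
      index.compute(
        tag,
        (_, existing) => {
          if (existing == null || !existing.addKey(key)) new KeySet(key)
          else existing
        })
    }

    /** Maintenance: drain dead keys, then unlink sets that became tombstones. */
    def prune(live: K => Boolean): Unit = {
      index.forEach((tag, set) => {
        if (set.retain(live)) {
          // remove(key, value) unlinks only if the mapping is still `set`, so
          // a fresh KeySet installed by a concurrent add() is never clobbered.
          index.remove(tag, set)
        }
      })
    }

    def get(tag: String): Seq[K] =
      Option(index.get(tag)).map(_.snapshot).getOrElse(Seq.empty)
  }

For instance, after add("t", k1) and add("t", k2), calling prune(Set(k1))
leaves get("t") == Seq(k1), and a later prune(_ => false) turns the set into a
tombstone and unlinks the tag, mirroring how periodicMaintenance() retires tag
entries whose queries have left queryCache.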