[SPARK-47233][CONNECT][SS][2/2] Client & Server logic for Client side streaming query listener

### What changes were proposed in this pull request?

Server- and client-side logic for the client-side streaming query listener.

The client sends an `add_listener_bus_listener` RPC for the first listener ever added.
Upon receiving the RPC, the server starts a long-running thread and registers a new `SparkConnectListenerBusListener`. The listener streams the listener events back to the client using the `responseObserver` created in the `executeHandler` of the `add_listener_bus_listener` call.
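
For concreteness, here is a minimal sketch of the command the client issues for the first listener, built from the proto messages this PR adds (it mirrors `_register_server_side_listener` in the diff below):

```python
import pyspark.sql.connect.proto as pb2

# Ask the server to register a SparkConnectListenerBusListener for this session.
cmd = pb2.StreamingQueryListenerBusCommand()
cmd.add_listener_bus_listener = True
exec_cmd = pb2.Command()
exec_cmd.streaming_query_listener_bus_command.CopyFrom(cmd)

# The command is then executed over a long-running iterator:
# result_iter = client.execute_command_as_iterator(exec_cmd)
```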

On the client side, a Spark client method, `execute_command_as_iterator`, is added to continuously receive new events from the server through a long-running iterator. The client starts a new thread for handling such events.
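
A condensed sketch of that event-handling loop (a simplified view of `_query_event_handler` in the diff below; `dispatch` is a hypothetical stand-in for deserializing the event and invoking the listener callbacks):

```python
from threading import Thread

def query_event_handler(events_iter):
    # Each response from the long-running iterator carries a batch of events.
    for result in events_iter:
        response = result["streaming_query_listener_events_result"]
        for event in response.events:
            # event.event_type and event.event_json identify and encode the
            # concrete QueryProgressEvent / QueryIdleEvent / QueryTerminatedEvent.
            dispatch(event)  # hypothetical: deserialize and post to listeners

# handler_thread = Thread(target=query_event_handler, args=(result_iter,))
# handler_thread.start()
```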

When the last client-side listener is removed and the client sends the `remove_listener_bus_listener` call, or when the `send` method of `SparkConnectListenerBusListener` throws, the long-running server thread is stopped. As a result, the final `ResultComplete` is sent to the client, closing the client's long-running iterator.
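
From the user's perspective the listener API is unchanged; only the transport differs. A hedged end-to-end sketch (the connection URL and the listener bodies are illustrative only):

```python
from pyspark.sql import SparkSession
from pyspark.sql.streaming import StreamingQueryListener

class MyListener(StreamingQueryListener):
    def onQueryStarted(self, event):
        print("started:", event.id)

    def onQueryProgress(self, event):
        print("batch:", event.progress.batchId)

    def onQueryIdle(self, event):
        pass

    def onQueryTerminated(self, event):
        print("terminated:", event.id)

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
listener = MyListener()
spark.streams.addListener(listener)     # first add: triggers add_listener_bus_listener
# ... start and run streaming queries; events are delivered to MyListener ...
spark.streams.removeListener(listener)  # last remove: triggers remove_listener_bus_listener
```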

### Why are the changes needed?

Development of Spark Connect streaming.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Added unit tests. Removed an old unit test that was created to verify server-side listener limitations.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #46037 from WweiL/SPARK-47233-client-side-listener-2.

Authored-by: Wei Liu <wei.liu@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
WweiL authored and HyukjinKwon committed Apr 16, 2024
1 parent e815012 commit 51d3efc
Showing 5 changed files with 376 additions and 80 deletions.
@@ -51,7 +51,11 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
val streamingQueryStartedEventCache
: ConcurrentMap[String, StreamingQueryListener.QueryStartedEvent] = new ConcurrentHashMap()

-  def isServerSideListenerRegistered: Boolean = streamingQueryServerSideListener.isDefined
val lock = new Object()

def isServerSideListenerRegistered: Boolean = lock.synchronized {
streamingQueryServerSideListener.isDefined
}

/**
* The initialization of the server side listener and related resources. This method is called
@@ -62,7 +66,7 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
* @param responseObserver
* the responseObserver created from the first long running executeThread.
*/
-  def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = lock.synchronized {
val serverListener = new SparkConnectListenerBusListener(this, responseObserver)
sessionHolder.session.streams.addListener(serverListener)
streamingQueryServerSideListener = Some(serverListener)
@@ -76,7 +80,7 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
* the latch, so the long-running thread can proceed to send back the final ResultComplete
* response.
*/
-  def cleanUp(): Unit = {
def cleanUp(): Unit = lock.synchronized {
streamingQueryServerSideListener.foreach { listener =>
sessionHolder.session.streams.removeListener(listener)
}
@@ -106,18 +110,18 @@ private[sql] class SparkConnectListenerBusListener(
// all related sources are cleaned up, and the long-running thread will proceed to send
// the final ResultComplete response.
private def send(eventJson: String, eventType: StreamingQueryEventType): Unit = {
-    val event = StreamingQueryListenerEvent
-      .newBuilder()
-      .setEventJson(eventJson)
-      .setEventType(eventType)
-      .build()
try {
val event = StreamingQueryListenerEvent
.newBuilder()
.setEventJson(eventJson)
.setEventType(eventType)
.build()

-    val respBuilder = StreamingQueryListenerEventsResult.newBuilder()
-    val eventResult = respBuilder
-      .addAllEvents(Array[StreamingQueryListenerEvent](event).toImmutableArraySeq.asJava)
-      .build()
val respBuilder = StreamingQueryListenerEventsResult.newBuilder()
val eventResult = respBuilder
.addAllEvents(Array[StreamingQueryListenerEvent](event).toImmutableArraySeq.asJava)
.build()

-    try {
responseObserver.onNext(
ExecutePlanResponse
.newBuilder()
@@ -143,14 +147,24 @@ }
}

override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
logDebug(
s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " +
s"Sending QueryProgressEvent to client, id: ${event.progress.id}" +
s" runId: ${event.progress.runId}, batch: ${event.progress.batchId}.")
send(event.json, StreamingQueryEventType.QUERY_PROGRESS_EVENT)
}

override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
logDebug(
s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " +
s"Sending QueryTerminatedEvent to client, id: ${event.id} runId: ${event.runId}.")
send(event.json, StreamingQueryEventType.QUERY_TERMINATED_EVENT)
}

override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
logDebug(
s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " +
s"Sending QueryIdleEvent to client, id: ${event.id} runId: ${event.runId}.")
send(event.json, StreamingQueryEventType.QUERY_IDLE_EVENT)
}
}
25 changes: 25 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -1067,6 +1067,28 @@ def execute_command(
else:
return (None, properties)

def execute_command_as_iterator(
self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None
) -> Iterator[Dict[str, Any]]:
"""
        Execute the given command. Similar to execute_command, but the responses are returned as an iterator, using yield.
"""
logger.info(f"Execute command as iterator for command {self._proto_to_string(command)}")
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
req.plan.command.CopyFrom(command)
for response in self._execute_and_fetch_as_iterator(req, observations or {}):
if isinstance(response, dict):
yield response
else:
raise PySparkValueError(
error_class="UNKNOWN_RESPONSE",
message_parameters={
"response": str(response),
},
)

def same_semantics(self, plan: pb2.Plan, other: pb2.Plan) -> bool:
"""
return if two plans have the same semantics.
@@ -1330,6 +1352,9 @@ def handle_response(
if b.HasField("streaming_query_manager_command_result"):
cmd_result = b.streaming_query_manager_command_result
yield {"streaming_query_manager_command_result": cmd_result}
if b.HasField("streaming_query_listener_events_result"):
event_result = b.streaming_query_listener_events_result
yield {"streaming_query_listener_events_result": event_result}
if b.HasField("get_resources_command_result"):
resources = {}
for key, resource in b.get_resources_command_result.resources.items():
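
To show how the pieces above fit together, here is a hedged sketch of consuming the new `streaming_query_listener_events_result` responses via `execute_command_as_iterator` (the names come from this diff; the consuming loop itself is illustrative):

```python
# `client` is an established SparkConnectClient; `exec_cmd` carries the
# StreamingQueryListenerBusCommand shown in the PR description above.
for response in client.execute_command_as_iterator(exec_cmd):
    if "streaming_query_listener_events_result" in response:
        result = response["streaming_query_listener_events_result"]
        if result.HasField("listener_bus_listener_added"):
            continue  # server confirmed the bus listener registration
        for event in result.events:
            print(event.event_type, event.event_json)
```
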
196 changes: 174 additions & 22 deletions python/pyspark/sql/connect/streaming/query.py
@@ -20,23 +20,27 @@

import json
import sys
-import pickle
-from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional
import warnings
from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional, Union, Iterator
from threading import Thread, Lock

from pyspark.errors import StreamingQueryException, PySparkValueError
import pyspark.sql.connect.proto as pb2
-from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.connect import proto
-from pyspark.sql.connect.utils import get_python_ver
from pyspark.sql.streaming import StreamingQueryListener
from pyspark.sql.streaming.listener import (
QueryStartedEvent,
QueryProgressEvent,
QueryIdleEvent,
QueryTerminatedEvent,
)
from pyspark.sql.streaming.query import (
StreamingQuery as PySparkStreamingQuery,
StreamingQueryManager as PySparkStreamingQueryManager,
)
from pyspark.errors.exceptions.connect import (
StreamingQueryException as CapturedStreamingQueryException,
)
-from pyspark.errors import PySparkPicklingError

if TYPE_CHECKING:
from pyspark.sql.connect.session import SparkSession
@@ -184,6 +188,7 @@ def _execute_streaming_query_cmd(
class StreamingQueryManager:
def __init__(self, session: "SparkSession") -> None:
self._session = session
self._sqlb = StreamingQueryListenerBus(self)

@property
def active(self) -> List[StreamingQuery]:
@@ -237,27 +242,13 @@ def resetTerminated(self) -> None:
resetTerminated.__doc__ = PySparkStreamingQueryManager.resetTerminated.__doc__

def addListener(self, listener: StreamingQueryListener) -> None:
-        listener._init_listener_id()
-        cmd = pb2.StreamingQueryManagerCommand()
-        expr = proto.PythonUDF()
-        try:
-            expr.command = CloudPickleSerializer().dumps(listener)
-        except pickle.PicklingError:
-            raise PySparkPicklingError(
-                error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
-                message_parameters={"name": "addListener"},
-            )
-        expr.python_ver = get_python_ver()
-        cmd.add_listener.python_listener_payload.CopyFrom(expr)
-        cmd.add_listener.id = listener._id
-        self._execute_streaming_query_manager_cmd(cmd)
listener._set_spark_session(self._session)
self._sqlb.append(listener)

addListener.__doc__ = PySparkStreamingQueryManager.addListener.__doc__

def removeListener(self, listener: StreamingQueryListener) -> None:
-        cmd = pb2.StreamingQueryManagerCommand()
-        cmd.remove_listener.id = listener._id
-        self._execute_streaming_query_manager_cmd(cmd)
self._sqlb.remove(listener)

removeListener.__doc__ = PySparkStreamingQueryManager.removeListener.__doc__

@@ -273,6 +264,167 @@ def _execute_streaming_query_manager_cmd(
)


class StreamingQueryListenerBus:
"""
A client side listener bus that is responsible for buffering client side listeners,
receive listener events and invoke correct listener call backs.
"""

def __init__(self, sqm: "StreamingQueryManager") -> None:
self._sqm = sqm
self._listener_bus: List[StreamingQueryListener] = []
self._execution_thread: Optional[Thread] = None
self._lock = Lock()

def append(self, listener: StreamingQueryListener) -> None:
"""
Append a listener to the local listener bus. When the added listener is
the first listener, request the server to create the server side listener
and start a thread to handle query events.
"""
with self._lock:
self._listener_bus.append(listener)

if len(self._listener_bus) == 1:
assert self._execution_thread is None
try:
result_iter = self._register_server_side_listener()
except Exception as e:
warnings.warn(
f"Failed to add the listener because of exception: {e}\n"
f"The listener is not added, please add it again."
)
self._listener_bus.remove(listener)
return
self._execution_thread = Thread(
target=self._query_event_handler, args=(result_iter,)
)
self._execution_thread.start()

def remove(self, listener: StreamingQueryListener) -> None:
"""
Remove the listener from the local listener bus.
        When the listener is not present in the listener bus, do nothing.
When the removed listener is the last listener, ask the server to remove
the server side listener.
As a result, the listener handling thread created before
will return after processing remaining listener events. This function blocks until
all events are processed.
"""
with self._lock:
if listener not in self._listener_bus:
return

if len(self._listener_bus) == 1:
cmd = pb2.StreamingQueryListenerBusCommand()
cmd.remove_listener_bus_listener = True
exec_cmd = pb2.Command()
exec_cmd.streaming_query_listener_bus_command.CopyFrom(cmd)
try:
self._sqm._session.client.execute_command(exec_cmd)
except Exception as e:
warnings.warn(
f"Failed to remove the listener because of exception: {e}\n"
f"The listener is not removed, please remove it again."
)
return
if self._execution_thread is not None:
self._execution_thread.join()
self._execution_thread = None

self._listener_bus.remove(listener)

def _register_server_side_listener(self) -> Iterator[Dict[str, Any]]:
"""
        Send the add listener request to the server. After receiving confirmation from
        the server, start a new thread to handle these events.
"""
cmd = pb2.StreamingQueryListenerBusCommand()
cmd.add_listener_bus_listener = True
exec_cmd = pb2.Command()
exec_cmd.streaming_query_listener_bus_command.CopyFrom(cmd)
result_iter = self._sqm._session.client.execute_command_as_iterator(exec_cmd)
        # The main thread blocks here until the listener_bus_listener_added confirmation is received.
for result in result_iter:
response = cast(
pb2.StreamingQueryListenerEventsResult,
result["streaming_query_listener_events_result"],
)
if response.HasField("listener_bus_listener_added"):
break
return result_iter

def _query_event_handler(self, iter: Iterator[Dict[str, Any]]) -> None:
"""
        Handler function passed to the new thread. If any error occurs while receiving
        listener events, the connection is considered unstable; in that case, remove all
        listeners and tell the user to add them back.
"""
try:
for result in iter:
response = cast(
pb2.StreamingQueryListenerEventsResult,
result["streaming_query_listener_events_result"],
)
for event in response.events:
deserialized_event = self.deserialize(event)
self.post_to_all(deserialized_event)

except Exception as e:
warnings.warn(
"StreamingQueryListenerBus Handler thread received exception, all client side "
f"listeners are removed and handler thread is terminated. The error is: {e}"
)
with self._lock:
self._execution_thread = None
self._listener_bus.clear()
return

@staticmethod
def deserialize(
event: pb2.StreamingQueryListenerEvent,
) -> Union["QueryProgressEvent", "QueryIdleEvent", "QueryTerminatedEvent"]:
if event.event_type == proto.StreamingQueryEventType.QUERY_PROGRESS_EVENT:
return QueryProgressEvent.fromJson(json.loads(event.event_json))
elif event.event_type == proto.StreamingQueryEventType.QUERY_TERMINATED_EVENT:
return QueryTerminatedEvent.fromJson(json.loads(event.event_json))
elif event.event_type == proto.StreamingQueryEventType.QUERY_IDLE_EVENT:
return QueryIdleEvent.fromJson(json.loads(event.event_json))
else:
raise PySparkValueError(
error_class="UNKNOWN_VALUE_FOR",
message_parameters={"var": f"proto.StreamingQueryEventType: {event.event_type}"},
)

def post_to_all(
self,
event: Union[
"QueryStartedEvent", "QueryProgressEvent", "QueryIdleEvent", "QueryTerminatedEvent"
],
) -> None:
"""
        Post listener events to all active listeners. Note that if one listener throws,
        it should not affect other listeners.
"""
with self._lock:
for listener in self._listener_bus:
try:
if isinstance(event, QueryStartedEvent):
listener.onQueryStarted(event)
elif isinstance(event, QueryProgressEvent):
listener.onQueryProgress(event)
elif isinstance(event, QueryIdleEvent):
listener.onQueryIdle(event)
elif isinstance(event, QueryTerminatedEvent):
listener.onQueryTerminated(event)
else:
warnings.warn(f"Unknown StreamingQueryListener event: {event}")
except Exception as e:
warnings.warn(f"Listener {str(listener)} threw an exception\n{e}")


def _test() -> None:
import doctest
import os
