[SPARK-48048][CONNECT][SS] Added client side listener support for Scala
### What changes were proposed in this pull request?

Added client-side Streaming Listener support for Scala

### Why are the changes needed?

Supports the Streaming Listener on the client side for Spark Connect, which gives a better user experience than the previous server-side listener (no breaking change compared to legacy mode).
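
For illustration, a minimal sketch of how a client-side listener would be used from the Spark Connect Scala client (the remote URL, listener body, and object name are illustrative; `addListener`/`removeListener` are the existing public `StreamingQueryManager` API):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

object ListenerExample {
  def main(args: Array[String]): Unit = {
    // Connect to a Spark Connect endpoint (URL is illustrative).
    val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()

    // With this change, the listener callbacks run in the client process, as in legacy mode.
    val listener = new StreamingQueryListener {
      override def onQueryStarted(event: QueryStartedEvent): Unit =
        println(s"Query started: ${event.id}")
      override def onQueryProgress(event: QueryProgressEvent): Unit =
        println(s"Progress: ${event.progress.numInputRows} input rows")
      override def onQueryIdle(event: QueryIdleEvent): Unit =
        println(s"Query idle: ${event.id}")
      override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
        println(s"Query terminated: ${event.id}")
    }

    spark.streams.addListener(listener)
    // ... start and run streaming queries; events are delivered to the listener ...
    spark.streams.removeListener(listener)
  }
}
```

Because the listener object stays in the client JVM, user code written against a classic (non-Connect) session keeps working unchanged.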

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added unit test.
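
For reference, a hypothetical shape of such a test (not the exact suite added in this commit): register a listener that collects events, run a short rate-source query, and assert that the expected event types reach the client.

```scala
// Hypothetical test sketch; class and method names are illustrative.
import java.util.concurrent.ConcurrentLinkedQueue
import scala.jdk.CollectionConverters._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

object ListenerTestSketch {
  // Collects every event it receives so the test can assert on them later.
  class CollectingListener extends StreamingQueryListener {
    val events = new ConcurrentLinkedQueue[Event]()
    override def onQueryStarted(e: QueryStartedEvent): Unit = events.add(e)
    override def onQueryProgress(e: QueryProgressEvent): Unit = events.add(e)
    override def onQueryIdle(e: QueryIdleEvent): Unit = events.add(e)
    override def onQueryTerminated(e: QueryTerminatedEvent): Unit = events.add(e)
  }

  def runScenario(spark: SparkSession): Unit = {
    val listener = new CollectingListener
    spark.streams.addListener(listener)
    val query = spark.readStream
      .format("rate")
      .load()
      .writeStream
      .format("console")
      .start()
    // Let a few micro-batches run, then stop the query.
    query.awaitTermination(5000L)
    query.stop()
    // Events arrive asynchronously; a real test would retry this assertion.
    assert(listener.events.asScala.exists(_.isInstanceOf[QueryStartedEvent]))
    spark.streams.removeListener(listener)
  }
}
```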

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#46287 from bogao007/client-listener.

Authored-by: bogao007 <bo.gao@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
bogao007 authored and JacobZheng0927 committed May 11, 2024
1 parent 7b433c5 commit 62ce0c3
Showing 7 changed files with 396 additions and 78 deletions.
@@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.streaming.AvailableNowTrigger
import org.apache.spark.sql.execution.streaming.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.OneTimeTrigger
import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger
import org.apache.spark.sql.streaming.StreamingQueryListener.QueryStartedEvent
import org.apache.spark.sql.types.NullType
import org.apache.spark.util.SparkSerDeUtils

@@ -297,6 +298,11 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging {
.build()

val resp = ds.sparkSession.execute(startCmd).head
if (resp.getWriteStreamOperationStartResult.hasQueryStartedEventJson) {
val event = QueryStartedEvent.fromJson(
resp.getWriteStreamOperationStartResult.getQueryStartedEventJson)
ds.sparkSession.streams.streamingQueryListenerBus.postToAll(event)
}
RemoteStreamingQuery.fromStartCommandResponse(ds.sparkSession, resp)
}

@@ -19,8 +19,9 @@ package org.apache.spark.sql.streaming

import java.util.UUID

import org.json4s.{JObject, JString}
import org.json4s.JsonAST.JValue
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}
import org.json4s.{JObject, JString, JValue}
import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc}
import org.json4s.jackson.JsonMethods.{compact, render}

@@ -120,6 +121,21 @@ object StreamingQueryListener extends Serializable {
}
}

private[spark] object QueryStartedEvent {
private val mapper = {
val ret = new ObjectMapper() with ClassTagExtensions
ret.registerModule(DefaultScalaModule)
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
ret
}

private[spark] def jsonString(event: QueryStartedEvent): String =
mapper.writeValueAsString(event)

private[spark] def fromJson(json: String): QueryStartedEvent =
mapper.readValue[QueryStartedEvent](json)
}

/**
* Event representing any progress updates in a query.
* @param progress
@@ -136,6 +152,21 @@ object StreamingQueryListener extends Serializable {
private def jsonValue: JValue = JObject("progress" -> progress.jsonValue)
}

private[spark] object QueryProgressEvent {
private val mapper = {
val ret = new ObjectMapper() with ClassTagExtensions
ret.registerModule(DefaultScalaModule)
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
ret
}

private[spark] def jsonString(event: QueryProgressEvent): String =
mapper.writeValueAsString(event)

private[spark] def fromJson(json: String): QueryProgressEvent =
mapper.readValue[QueryProgressEvent](json)
}

/**
* Event representing that a query is idle and waiting for new data to process.
*
@@ -161,6 +192,21 @@ object StreamingQueryListener extends Serializable {
}
}

private[spark] object QueryIdleEvent {
private val mapper = {
val ret = new ObjectMapper() with ClassTagExtensions
ret.registerModule(DefaultScalaModule)
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
ret
}

private[spark] def jsonString(event: QueryIdleEvent): String =
mapper.writeValueAsString(event)

private[spark] def fromJson(json: String): QueryIdleEvent =
mapper.readValue[QueryIdleEvent](json)
}

/**
* Event representing the termination of a query.
*
@@ -199,4 +245,19 @@ object StreamingQueryListener extends Serializable {
("errorClassOnException" -> JString(errorClassOnException.orNull))
}
}

private[spark] object QueryTerminatedEvent {
private val mapper = {
val ret = new ObjectMapper() with ClassTagExtensions
ret.registerModule(DefaultScalaModule)
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
ret
}

private[spark] def jsonString(event: QueryTerminatedEvent): String =
mapper.writeValueAsString(event)

private[spark] def fromJson(json: String): QueryTerminatedEvent =
mapper.readValue[QueryTerminatedEvent](json)
}
}
@@ -0,0 +1,152 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.streaming

import java.util.concurrent.CopyOnWriteArrayList

import scala.jdk.CollectionConverters._

import org.apache.spark.connect.proto.{Command, ExecutePlanResponse, Plan, StreamingQueryEventType}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.client.CloseableIterator
import org.apache.spark.sql.streaming.StreamingQueryListener.{Event, QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}

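/**
* A client-side event bus that forwards streaming query events emitted by the server to all
* [[StreamingQueryListener]]s registered on this session. Adding the first listener starts a
* long-running server-side event stream and a handler thread that dispatches the events;
* removing the last listener tears both down.
*/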
class StreamingQueryListenerBus(sparkSession: SparkSession) extends Logging {
private val listeners = new CopyOnWriteArrayList[StreamingQueryListener]()
private var executionThread: Option[Thread] = Option.empty

val lock = new Object()

def close(): Unit = {
listeners.forEach(remove(_))
}

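// Registers a listener. The first registration also registers a listener-bus listener on the
// server and starts the event-handler thread; on failure the listener is removed again.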
def append(listener: StreamingQueryListener): Unit = lock.synchronized {
listeners.add(listener)

if (listeners.size() == 1) {
var iter: Option[CloseableIterator[ExecutePlanResponse]] = Option.empty
try {
iter = Some(registerServerSideListener())
} catch {
case e: Exception =>
logWarning("Failed to add the listener, please add it again.", e)
listeners.remove(listener)
return
}
executionThread = Some(new Thread(new Runnable {
def run(): Unit = {
queryEventHandler(iter.get)
}
}))
// Start the thread
executionThread.get.start()
}
}

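// Deregisters a listener. Removing the last one also sends RemoveListenerBusListener to the
// server and interrupts the event-handler thread.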
def remove(listener: StreamingQueryListener): Unit = lock.synchronized {
if (listeners.size() == 1) {
val cmdBuilder = Command.newBuilder()
cmdBuilder.getStreamingQueryListenerBusCommandBuilder
.setRemoveListenerBusListener(true)
try {
sparkSession.execute(cmdBuilder.build())
} catch {
case e: Exception =>
logWarning("Failed to remove the listener, please remove it again.", e)
return
}
if (executionThread.isDefined) {
executionThread.get.interrupt()
executionThread = Option.empty
}
}
listeners.remove(listener)
}

def list(): Array[StreamingQueryListener] = lock.synchronized {
listeners.asScala.toArray
}

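// Sends AddListenerBusListener and blocks until the server acknowledges that the listener-bus
// listener was added, then returns the long-running response iterator.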
def registerServerSideListener(): CloseableIterator[ExecutePlanResponse] = {
val cmdBuilder = Command.newBuilder()
cmdBuilder.getStreamingQueryListenerBusCommandBuilder
.setAddListenerBusListener(true)

val plan = Plan.newBuilder().setCommand(cmdBuilder.build()).build()
val iterator = sparkSession.client.execute(plan)
while (iterator.hasNext) {
val response = iterator.next()
if (response.getStreamingQueryListenerEventsResult.hasListenerBusListenerAdded &&
response.getStreamingQueryListenerEventsResult.getListenerBusListenerAdded) {
return iterator
}
}
iterator
}

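// Runs on the handler thread: pulls listener events from the long-running response stream and
// dispatches them; on error, all client-side listeners are removed and the thread terminates.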
def queryEventHandler(iter: CloseableIterator[ExecutePlanResponse]): Unit = {
try {
while (iter.hasNext) {
val response = iter.next()
val listenerEvents = response.getStreamingQueryListenerEventsResult.getEventsList
listenerEvents.forEach(event => {
event.getEventType match {
case StreamingQueryEventType.QUERY_PROGRESS_EVENT =>
postToAll(QueryProgressEvent.fromJson(event.getEventJson))
case StreamingQueryEventType.QUERY_IDLE_EVENT =>
postToAll(QueryIdleEvent.fromJson(event.getEventJson))
case StreamingQueryEventType.QUERY_TERMINATED_EVENT =>
postToAll(QueryTerminatedEvent.fromJson(event.getEventJson))
case _ =>
logWarning(s"Unknown StreamingQueryListener event: $event")
}
})
}
} catch {
case e: Exception =>
logWarning("StreamingQueryListenerBus Handler thread received exception, all client" +
" side listeners are removed and handler thread is terminated.", e)
lock.synchronized {
executionThread = Option.empty
listeners.forEach(remove(_))
}
}
}

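// Dispatches one event to every registered listener, logging (rather than propagating) any
// exception thrown by an individual listener.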
def postToAll(event: Event): Unit = lock.synchronized {
listeners.forEach(listener =>
try {
event match {
case t: QueryStartedEvent =>
listener.onQueryStarted(t)
case t: QueryProgressEvent =>
listener.onQueryProgress(t)
case t: QueryIdleEvent =>
listener.onQueryIdle(t)
case t: QueryTerminatedEvent =>
listener.onQueryTerminated(t)
case _ => logWarning(s"Unknown StreamingQueryListener event: $event")
}
} catch {
case e: Exception =>
logWarning(s"Listener $listener threw an exception", e)
})
}
}
@@ -22,16 +22,13 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}

import scala.jdk.CollectionConverters._

import com.google.protobuf.ByteString

import org.apache.spark.annotation.Evolving
import org.apache.spark.connect.proto.Command
import org.apache.spark.connect.proto.StreamingQueryManagerCommand
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.common.{InvalidPlanInput, StreamingListenerPacket}
import org.apache.spark.util.SparkSerDeUtils
import org.apache.spark.sql.connect.common.InvalidPlanInput

/**
* A class to manage all the [[StreamingQuery]] active in a `SparkSession`.
@@ -50,6 +47,12 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
private lazy val listenerCache: ConcurrentMap[String, StreamingQueryListener] =
new ConcurrentHashMap()

private[spark] val streamingQueryListenerBus = new StreamingQueryListenerBus(sparkSession)

private[spark] def close(): Unit = {
streamingQueryListenerBus.close()
}

/**
* Returns a list of active queries associated with this SQLContext
*
@@ -153,17 +156,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
* @since 3.5.0
*/
def addListener(listener: StreamingQueryListener): Unit = {
// TODO: [SPARK-44400] Improve the Listener to provide users a way to access the Spark session
// and perform arbitrary actions inside the Listener. Right now users can use
// `val spark = SparkSession.builder.getOrCreate()` to create a Spark session inside the
// Listener, but this is a legacy session instead of a connect remote session.
val id = UUID.randomUUID.toString
cacheListenerById(id, listener)
executeManagerCmd(
_.getAddListenerBuilder
.setListenerPayload(ByteString.copyFrom(SparkSerDeUtils
.serialize(StreamingListenerPacket(id, listener))))
.setId(id))
streamingQueryListenerBus.append(listener)
}

/**
@@ -172,11 +165,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
* @since 3.5.0
*/
def removeListener(listener: StreamingQueryListener): Unit = {
val id = getIdByListener(listener)
executeManagerCmd(
_.getRemoveListenerBuilder
.setId(id))
removeCachedListener(id)
streamingQueryListenerBus.remove(listener)
}

/**
@@ -185,10 +174,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
* @since 3.5.0
*/
def listListeners(): Array[StreamingQueryListener] = {
executeManagerCmd(_.setListListeners(true)).getListListeners.getListenerIdsList.asScala
.filter(listenerCache.containsKey(_))
.map(listenerCache.get(_))
.toArray
streamingQueryListenerBus.list()
}

private def executeManagerCmd(
@@ -440,6 +440,10 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.streaming.RemoteStreamingQuery"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.streaming.RemoteStreamingQuery$"),
// Skip client side listener specific class
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.streaming.StreamingQueryListenerBus"
),

// Encoders are in the wrong JAR
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"),
