Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48048][CONNECT][SS] Added client side listener support for Scala #46287

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.streaming.AvailableNowTrigger
import org.apache.spark.sql.execution.streaming.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.OneTimeTrigger
import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger
import org.apache.spark.sql.streaming.StreamingQueryListener.QueryStartedEvent
import org.apache.spark.sql.types.NullType
import org.apache.spark.util.SparkSerDeUtils

Expand Down Expand Up @@ -297,6 +298,11 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging {
.build()

val resp = ds.sparkSession.execute(startCmd).head
if (resp.getWriteStreamOperationStartResult.hasQueryStartedEventJson) {
val event = QueryStartedEvent.fromJson(
resp.getWriteStreamOperationStartResult.getQueryStartedEventJson)
ds.sparkSession.streams.streamingQueryListenerBus.postToAll(event)
}
RemoteStreamingQuery.fromStartCommandResponse(ds.sparkSession, resp)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ package org.apache.spark.sql.streaming

import java.util.UUID

import org.json4s.{JObject, JString}
import org.json4s.JsonAST.JValue
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}
import org.json4s.{JObject, JString, JValue}
import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc}
import org.json4s.jackson.JsonMethods.{compact, render}

Expand Down Expand Up @@ -120,6 +121,21 @@ object StreamingQueryListener extends Serializable {
}
}

private[spark] object QueryStartedEvent {
private val mapper = {
val ret = new ObjectMapper() with ClassTagExtensions
ret.registerModule(DefaultScalaModule)
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
ret
}

private[spark] def jsonString(event: QueryStartedEvent): String =
mapper.writeValueAsString(event)

private[spark] def fromJson(json: String): QueryStartedEvent =
mapper.readValue[QueryStartedEvent](json)
}

/**
* Event representing any progress updates in a query.
* @param progress
Expand All @@ -136,6 +152,21 @@ object StreamingQueryListener extends Serializable {
private def jsonValue: JValue = JObject("progress" -> progress.jsonValue)
}

private[spark] object QueryProgressEvent {
private val mapper = {
val ret = new ObjectMapper() with ClassTagExtensions
ret.registerModule(DefaultScalaModule)
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
ret
}

private[spark] def jsonString(event: QueryProgressEvent): String =
mapper.writeValueAsString(event)

private[spark] def fromJson(json: String): QueryProgressEvent =
mapper.readValue[QueryProgressEvent](json)
}

/**
* Event representing that query is idle and waiting for new data to process.
*
Expand All @@ -161,6 +192,21 @@ object StreamingQueryListener extends Serializable {
}
}

private[spark] object QueryIdleEvent {
private val mapper = {
val ret = new ObjectMapper() with ClassTagExtensions
ret.registerModule(DefaultScalaModule)
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
ret
}

private[spark] def jsonString(event: QueryTerminatedEvent): String =
mapper.writeValueAsString(event)

private[spark] def fromJson(json: String): QueryTerminatedEvent =
mapper.readValue[QueryTerminatedEvent](json)
}

/**
* Event representing that termination of a query.
*
Expand Down Expand Up @@ -199,4 +245,19 @@ object StreamingQueryListener extends Serializable {
("errorClassOnException" -> JString(errorClassOnException.orNull))
}
}

private[spark] object QueryTerminatedEvent {
private val mapper = {
val ret = new ObjectMapper() with ClassTagExtensions
ret.registerModule(DefaultScalaModule)
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
ret
}

private[spark] def jsonString(event: QueryTerminatedEvent): String =
mapper.writeValueAsString(event)

private[spark] def fromJson(json: String): QueryTerminatedEvent =
mapper.readValue[QueryTerminatedEvent](json)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.streaming

import java.util.concurrent.CopyOnWriteArrayList

import scala.jdk.CollectionConverters._

import org.apache.spark.connect.proto.{Command, ExecutePlanResponse, Plan, StreamingQueryEventType}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.client.CloseableIterator
import org.apache.spark.sql.streaming.StreamingQueryListener.{Event, QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}

class StreamingQueryListenerBus(sparkSession: SparkSession) extends Logging {
private val listeners = new CopyOnWriteArrayList[StreamingQueryListener]()
private var executionThread: Option[Thread] = Option.empty

val lock = new Object()

def close(): Unit = {
listeners.forEach(remove(_))
}

def append(listener: StreamingQueryListener): Unit = lock.synchronized {
listeners.add(listener)

if (listeners.size() == 1) {
var iter: Option[CloseableIterator[ExecutePlanResponse]] = Option.empty
try {
iter = Some(registerServerSideListener())
} catch {
case e: Exception =>
logWarning("Failed to add the listener, please add it again.", e)
listeners.remove(listener)
return
}
executionThread = Some(new Thread(new Runnable {
def run(): Unit = {
queryEventHandler(iter.get)
}
}))
// Start the thread
executionThread.get.start()
logInfo(
"Started the execution thread for StreamingQueryListenerBus with name: " +
executionThread.get.getName())
}
}

def remove(listener: StreamingQueryListener): Unit = lock.synchronized {
if (listeners.size() == 1) {
val cmdBuilder = Command.newBuilder()
cmdBuilder.getStreamingQueryListenerBusCommandBuilder
.setRemoveListenerBusListener(true)
try {
sparkSession.execute(cmdBuilder.build())
} catch {
case e: Exception =>
logWarning("Failed to remove the listener, please remove it again.", e)
return
}
if (executionThread.isDefined) {
executionThread.get.interrupt()
executionThread = Option.empty
}
}
listeners.remove(listener)
}

def list(): Array[StreamingQueryListener] = lock.synchronized {
listeners.asScala.toArray
}

def registerServerSideListener(): CloseableIterator[ExecutePlanResponse] = {
val cmdBuilder = Command.newBuilder()
cmdBuilder.getStreamingQueryListenerBusCommandBuilder
.setAddListenerBusListener(true)

val plan = Plan.newBuilder().setCommand(cmdBuilder.build()).build()
val iterator = sparkSession.client.execute(plan)
while (iterator.hasNext) {
val response = iterator.next()
if (response.getStreamingQueryListenerEventsResult.hasListenerBusListenerAdded &&
response.getStreamingQueryListenerEventsResult.getListenerBusListenerAdded) {
return iterator
}
}
iterator
}

def queryEventHandler(iter: CloseableIterator[ExecutePlanResponse]): Unit = {
try {
while (iter.hasNext) {
val response = iter.next()
val listenerEvents = response.getStreamingQueryListenerEventsResult.getEventsList
listenerEvents.forEach(event => {
event.getEventType match {
case StreamingQueryEventType.QUERY_PROGRESS_EVENT =>
postToAll(QueryProgressEvent.fromJson(event.getEventJson))
case StreamingQueryEventType.QUERY_IDLE_EVENT =>
postToAll(QueryIdleEvent.fromJson(event.getEventJson))
case StreamingQueryEventType.QUERY_TERMINATED_EVENT =>
postToAll(QueryTerminatedEvent.fromJson(event.getEventJson))
case _ =>
logWarning(s"Unknown StreamingQueryListener event: $event")
}
})
}
} catch {
case e: Exception =>
logWarning("Failed to handle the event, please add the listener again. ", e)
bogao007 marked this conversation as resolved.
Show resolved Hide resolved
lock.synchronized {
executionThread = Option.empty
listeners.forEach(remove(_))
}
}
}

def postToAll(event: Event): Unit = lock.synchronized {
listeners.forEach(listener =>
try {
event match {
case t: QueryStartedEvent =>
listener.onQueryStarted(t)
case t: QueryProgressEvent =>
listener.onQueryProgress(t)
case t: QueryIdleEvent =>
listener.onQueryIdle(t)
case t: QueryTerminatedEvent =>
listener.onQueryTerminated(t)
case _ => logWarning(s"Unknown StreamingQueryListener event: $event")
}
} catch {
bogao007 marked this conversation as resolved.
Show resolved Hide resolved
case e: Exception =>
logWarning(s"Listener $listener threw an exception", e)
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,13 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}

import scala.jdk.CollectionConverters._

import com.google.protobuf.ByteString

import org.apache.spark.annotation.Evolving
import org.apache.spark.connect.proto.Command
import org.apache.spark.connect.proto.StreamingQueryManagerCommand
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.common.{InvalidPlanInput, StreamingListenerPacket}
import org.apache.spark.util.SparkSerDeUtils
import org.apache.spark.sql.connect.common.InvalidPlanInput

/**
* A class to manage all the [[StreamingQuery]] active in a `SparkSession`.
Expand All @@ -50,6 +47,12 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
private lazy val listenerCache: ConcurrentMap[String, StreamingQueryListener] =
new ConcurrentHashMap()

private[spark] val streamingQueryListenerBus = new StreamingQueryListenerBus(sparkSession)

private[spark] def close(): Unit = {
streamingQueryListenerBus.close()
}

/**
* Returns a list of active queries associated with this SQLContext
*
Expand Down Expand Up @@ -153,17 +156,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
* @since 3.5.0
*/
def addListener(listener: StreamingQueryListener): Unit = {
// TODO: [SPARK-44400] Improve the Listener to provide users a way to access the Spark session
// and perform arbitrary actions inside the Listener. Right now users can use
// `val spark = SparkSession.builder.getOrCreate()` to create a Spark session inside the
// Listener, but this is a legacy session instead of a connect remote session.
val id = UUID.randomUUID.toString
cacheListenerById(id, listener)
executeManagerCmd(
_.getAddListenerBuilder
.setListenerPayload(ByteString.copyFrom(SparkSerDeUtils
.serialize(StreamingListenerPacket(id, listener))))
.setId(id))
streamingQueryListenerBus.append(listener)
}

/**
Expand All @@ -172,11 +165,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
* @since 3.5.0
*/
def removeListener(listener: StreamingQueryListener): Unit = {
val id = getIdByListener(listener)
executeManagerCmd(
_.getRemoveListenerBuilder
.setId(id))
removeCachedListener(id)
streamingQueryListenerBus.remove(listener)
}

/**
Expand All @@ -185,10 +174,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
* @since 3.5.0
*/
def listListeners(): Array[StreamingQueryListener] = {
executeManagerCmd(_.setListListeners(true)).getListListeners.getListenerIdsList.asScala
.filter(listenerCache.containsKey(_))
.map(listenerCache.get(_))
.toArray
streamingQueryListenerBus.list()
}

private def executeManagerCmd(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,10 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.streaming.RemoteStreamingQuery"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.streaming.RemoteStreamingQuery$"),
// Skip client side listener specific class
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.streaming.StreamingQueryListenerBus"
),

// Encoders are in the wrong JAR
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"),
Expand Down