diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ReportsSinkMetrics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ReportsSinkMetrics.java
new file mode 100644
index 0000000000000..97f588f5c1b13
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ReportsSinkMetrics.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read.streaming;
+
+import org.apache.spark.annotation.Evolving;
+
+import java.util.Map;
+
+/**
+ * A mix-in interface for streaming sinks to signal that they can report
+ * metrics.
+ *
+ * @since 3.4.0
+ */
+@Evolving
+public interface ReportsSinkMetrics {
+    /**
+     * Returns the metrics reported by the sink for this micro-batch.
+     */
+    Map<String, String> metrics();
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index f9313977fe1f2..8a89ca7b85dba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalP
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog.Table
-import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, ReportsSourceMetrics, SparkDataStream}
+import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, ReportsSinkMetrics, ReportsSourceMetrics, SparkDataStream}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2Relation, StreamWriterCommitProgress}
 import org.apache.spark.sql.streaming._
@@ -200,7 +200,16 @@ trait ProgressReporter extends Logging {
     } else {
       sinkCommitProgress.map(_ => 0L)
     }
-    val sinkProgress = SinkProgress(sink.toString, sinkOutput)
+
+    val sinkMetrics = sink match {
+      case withMetrics: ReportsSinkMetrics =>
+        withMetrics.metrics()
+      case _ => Map[String, String]().asJava
+    }
+
+    val sinkProgress = SinkProgress(
+      sink.toString, sinkOutput, sinkMetrics)
+
     val observedMetrics = extractObservedMetrics(hasNewData, lastExecution)
 
     val newProgress = new StreamingQueryProgress(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index 1565658777f15..3d206e7780c70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -238,7 +238,8 @@ class SourceProgress protected[sql](
 @Evolving
 class SinkProgress protected[sql](
     val description: String,
-    val numOutputRows: Long) extends Serializable {
+    val numOutputRows: Long,
+    val metrics: ju.Map[String, String] = Map[String, String]().asJava) extends Serializable {
 
   /** SinkProgress without custom metrics. */
   protected[sql] def this(description: String) = {
@@ -255,15 +256,17 @@ class SinkProgress protected[sql](
 
   private[sql] def jsonValue: JValue = {
     ("description" -> JString(description)) ~
-      ("numOutputRows" -> JInt(numOutputRows))
+      ("numOutputRows" -> JInt(numOutputRows)) ~
+      ("metrics" -> safeMapToJValue[String](metrics, s => JString(s)))
   }
 }
 
 private[sql] object SinkProgress {
   val DEFAULT_NUM_OUTPUT_ROWS: Long = -1L
 
-  def apply(description: String, numOutputRows: Option[Long]): SinkProgress =
-    new SinkProgress(description, numOutputRows.getOrElse(DEFAULT_NUM_OUTPUT_ROWS))
+  def apply(description: String, numOutputRows: Option[Long],
+      metrics: ju.Map[String, String] = Map[String, String]().asJava): SinkProgress =
+    new SinkProgress(description, numOutputRows.getOrElse(DEFAULT_NUM_OUTPUT_ROWS), metrics)
 }
 
 private object SafeJsonSerializer {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReportSinkMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReportSinkMetricsSuite.scala
new file mode 100644
index 0000000000000..17aef18634f29
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ReportSinkMetricsSuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql._
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.connector.read.streaming.ReportsSinkMetrics
+import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.sources.PackedRowWriterFactory
+import org.apache.spark.sql.internal.connector.{SimpleTableProvider, SupportsStreamingUpdateAsAppend}
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class ReportSinkMetricsSuite extends StreamTest {
+
+  import testImplicits._
+
+  test("test ReportSinkMetrics") {
+    val inputData = MemoryStream[Int]
+    val df = inputData.toDF()
+    var query: StreamingQuery = null
+
+    var metricsMap: java.util.Map[String, String] = null
+
+    val listener = new StreamingQueryListener {
+
+      override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {}
+
+      override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
+        metricsMap = event.progress.sink.metrics
+      }
+
+      override def onQueryTerminated(
+          event: StreamingQueryListener.QueryTerminatedEvent): Unit = {}
+    }
+
+    spark.streams.addListener(listener)
+
+    withTempDir { dir =>
+      try {
+        query =
+          df.writeStream
+            .outputMode("append")
+            .format("org.apache.spark.sql.streaming.TestSinkProvider")
+            .option("checkPointLocation", dir.toString)
+            .start()
+
+        inputData.addData(1, 2, 3)
+
+        failAfter(streamingTimeout) {
+          query.processAllAvailable()
+        }
+
+        assertResult(metricsMap) {
+          Map("metrics-1" -> "value-1", "metrics-2" -> "value-2").asJava
+        }
+      } finally {
+        if (query != null) {
+          query.stop()
+        }
+
+        spark.streams.removeListener(listener)
+      }
+    }
+  }
+}
+
+case class TestSinkRelation(override val sqlContext: SQLContext, data: DataFrame)
+  extends BaseRelation {
+  override def schema: StructType = data.schema
+}
+
+class TestSinkProvider extends SimpleTableProvider
+  with DataSourceRegister
+  with CreatableRelationProvider with Logging {
+
+  override def getTable(options: CaseInsensitiveStringMap): Table = {
+    TestSinkTable
+  }
+
+  def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+
+    TestSinkRelation(sqlContext, data)
+  }
+
+  def shortName(): String = "test"
+}
+
+object TestSinkTable extends Table with SupportsWrite with ReportsSinkMetrics with Logging {
+
+  override def name(): String = "test"
+
+  override def schema(): StructType = StructType(Nil)
+
+  override def capabilities(): java.util.Set[TableCapability] = {
+    java.util.EnumSet.of(TableCapability.STREAMING_WRITE)
+  }
+
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+    new WriteBuilder with SupportsTruncate with SupportsStreamingUpdateAsAppend {
+
+      override def truncate(): WriteBuilder = this
+
+      override def build(): Write = {
+        new Write {
+          override def toStreaming: StreamingWrite = {
+            new TestSinkWrite()
+          }
+        }
+      }
+    }
+  }
+
+  override def metrics(): java.util.Map[String, String] = {
+    Map("metrics-1" -> "value-1", "metrics-2" -> "value-2").asJava
+  }
+}
+
+class TestSinkWrite()
+  extends StreamingWrite with Logging with Serializable {
+
+  def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory =
+    PackedRowWriterFactory
+
+  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+
+  def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+}
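
Usage note (not part of the patch itself): once a sink's Table implementation mixes in ReportsSinkMetrics, the map returned by metrics() is copied into SinkProgress for every micro-batch and can be read back on the driver. A minimal sketch under that assumption, using only the APIs touched by this diff plus the existing StreamingQuery.lastProgress and StreamingQueryProgress.json accessors; the `query` variable and metric names are illustrative only:

    // Scala, on the driver, after at least one micro-batch has completed.
    // Programmatic access: the sink's metrics() output lands in SinkProgress.metrics.
    val sinkMetrics: java.util.Map[String, String] = query.lastProgress.sink.metrics

    // The same map is serialized under the new "metrics" key of the sink entry, e.g.
    // {"description":"...","numOutputRows":3,"metrics":{"metrics-1":"value-1","metrics-2":"value-2"}}
    println(query.lastProgress.json)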