[SPARK-24565][SS] Add API for in Structured Streaming for exposing output rows of each microbatch as a DataFrame #21571
@@ -24,12 +24,14 @@
 else:
     intlike = (int, long)

+from py4j.java_gateway import java_import
+
 from pyspark import since, keyword_only
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.sql.column import _to_seq
 from pyspark.sql.readwriter import OptionUtils, to_str
 from pyspark.sql.types import *
-from pyspark.sql.utils import StreamingQueryException
+from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException

 __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"]
@@ -1016,6 +1018,35 @@ def func_with_open_process_close(partition_id, iterator):
         self._jwrite.foreach(jForeachWriter)
         return self

+    @since(2.4)
+    def foreachBatch(self, func):
+        """
+        Sets the output of the streaming query to be processed using the provided
+        function. This is supported only in the micro-batch execution modes (that is,
+        when the trigger is not continuous). In every micro-batch, the provided function
+        will be called with (i) the output rows as a DataFrame and (ii) the batch
+        identifier. The batchId can be used to deduplicate and transactionally write
+        the output (that is, the provided DataFrame) to external systems. The output
+        DataFrame is guaranteed to be exactly the same for the same batchId (assuming
+        all operations are deterministic in the query).
+
+        .. note:: Evolving.
+
+        >>> def func(batch_df, batch_id):
+        ...     batch_df.collect()
+        ...
+        >>> writer = sdf.writeStream.foreachBatch(func)
+        """
+
+        from pyspark.java_gateway import ensure_callback_server_started
+        gw = self._spark._sc._gateway
+        java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*")
+
+        wrapped_func = ForeachBatchFunction(self._spark, func)
+        gw.jvm.PythonForeachBatchHelper.callForeachBatch(self._jwrite, wrapped_func)
+        ensure_callback_server_started(gw)
Review comment: this should be above, otherwise there is a race where the streaming query calls this Python func before the callback server is started.
Reply: This is not possible, because the callback from the JVM ForeachBatch sink to Python is made ONLY after the query is started. And the query cannot be started until this foreachBatch() method finishes and start() is called.
Reply: I am sorry if I'm mistaken, but can't we still put this above? It looks weird to ensure the callback server at the end.
+        return self
+
     @ignore_unicode_prefix
     @since(2.0)
     def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None,
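For orientation, here is a minimal usage sketch of the Python API added above. It is not part of the diff: the rate source, output path, and column name are illustrative assumptions; the point is simply that the function receives a regular DataFrame plus a batch id that can drive idempotent writes.

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName("foreachBatch-sketch").getOrCreate()

# Any streaming source works; 'rate' is a built-in testing source.
stream_df = spark.readStream.format("rate").option("rowsPerSecond", 5).load()

def write_batch(batch_df, batch_id):
    # batch_df is a plain (non-streaming) DataFrame, so the full batch API is available.
    # Tagging rows with batch_id is one way to make re-processed batches easy to deduplicate.
    (batch_df.withColumn("batch_id", F.lit(batch_id))
        .write.mode("append").parquet("/tmp/foreach_batch_output"))

query = stream_df.writeStream.foreachBatch(write_batch).start()
query.processAllAvailable()  # in a long-running job, use query.awaitTermination() instead
query.stop()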
@@ -2126,6 +2126,42 @@ class WriterWithNonCallableClose(WithProcess):
         tester.assert_invalid_writer(WriterWithNonCallableClose(),
                                      "'close' in provided object is not callable")

+    def test_streaming_foreachBatch(self):
+        q = None
+        collected = dict()
+
+        def collectBatch(batch_df, batch_id):
Review comment: collectBatch -> collect_batch per PEP 8.
+            collected[batch_id] = batch_df.collect()
+
+        try:
+            df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+            q = df.writeStream.foreachBatch(collectBatch).start()
+            q.processAllAvailable()
+            self.assertTrue(0 in collected)
+            self.assertTrue(len(collected[0]), 2)
+        finally:
+            if q:
+                q.stop()
+
+    def test_streaming_foreachBatch_propagates_python_errors(self):
+        from pyspark.sql.utils import StreamingQueryException
+
+        q = None
+
+        def collectBatch(df, id):
+            raise Exception("this should fail the query")
+
+        try:
+            df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+            q = df.writeStream.foreachBatch(collectBatch).start()
+            q.processAllAvailable()
+            self.fail("Expected a failure")
+        except StreamingQueryException as e:
+            self.assertTrue("this should fail" in str(e))
+        finally:
+            if q:
+                q.stop()
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
@@ -79,22 +79,8 @@ def _ensure_initialized(cls):
         java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
         java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")

-        # start callback server
-        # getattr will fallback to JVM, so we cannot test by hasattr()
-        if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
-            gw.callback_server_parameters.eager_load = True
-            gw.callback_server_parameters.daemonize = True
-            gw.callback_server_parameters.daemonize_connections = True
-            gw.callback_server_parameters.port = 0
-            gw.start_callback_server(gw.callback_server_parameters)
-            cbport = gw._callback_server.server_socket.getsockname()[1]
-            gw._callback_server.port = cbport
-            # gateway with real port
-            gw._python_proxy_port = gw._callback_server.port
-            # get the GatewayServer object in JVM by ID
-            jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
Review comment: Nit: we could remove this import in this file though.
-            # update the port of CallbackClient with real port
-            jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)
+        from pyspark.java_gateway import ensure_callback_server_started
+        ensure_callback_server_started(gw)

         # register serializer for TransformFunction
         # it happens before creating SparkContext when loading from checkpointing
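The two added lines replace the inlined callback-server setup with a shared helper. As a rough sketch inferred only from the removed lines above (the actual helper added to pyspark/java_gateway.py in this PR may differ in detail), ensure_callback_server_started presumably looks something like this:

from py4j.java_gateway import JavaObject

def ensure_callback_server_started(gw):
    # Start the py4j callback server on an ephemeral port if it is not running yet.
    # getattr falls back to the JVM, so hasattr() cannot be used for this check.
    if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
        gw.callback_server_parameters.eager_load = True
        gw.callback_server_parameters.daemonize = True
        gw.callback_server_parameters.daemonize_connections = True
        gw.callback_server_parameters.port = 0
        gw.start_callback_server(gw.callback_server_parameters)
        # The server bound an ephemeral port; propagate the real port to the
        # JVM-side CallbackClient so that the JVM can call back into Python.
        cbport = gw._callback_server.server_socket.getsockname()[1]
        gw._callback_server.port = cbport
        gw._python_proxy_port = gw._callback_server.port
        jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
        jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)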
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.api.python.PythonException
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.Sink
+import org.apache.spark.sql.streaming.DataStreamWriter
+
+class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: ExpressionEncoder[T])
+  extends Sink {
+
+  override def addBatch(batchId: Long, data: DataFrame): Unit = {
+    val resolvedEncoder = encoder.resolveAndBind(
+      data.logicalPlan.output,
+      data.sparkSession.sessionState.analyzer)
+    val rdd = data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag)
+    val ds = data.sparkSession.createDataset(rdd)(encoder)
+    batchWriter(ds, batchId)
+  }
+
+  override def toString(): String = "ForeachBatchSink"
+}
+
+
+/**
+ * Interface that is meant to be extended by Python classes via Py4J.
+ * Py4J allows Python classes to implement Java interfaces so that the JVM can call back
+ * Python objects. In this case, this allows the user-defined Python `foreachBatch` function
+ * to be called from the JVM when the query is active.
+ */
+trait PythonForeachBatchFunction {
+  /** Call the Python implementation of this function */
+  def call(batchDF: DataFrame, batchId: Long): Unit
+}
+
+object PythonForeachBatchHelper {
+  def callForeachBatch(dsw: DataStreamWriter[Row], pythonFunc: PythonForeachBatchFunction): Unit = {
+    dsw.foreachBatch(pythonFunc.call _)
+  }
+}
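The Python counterpart that implements this trait is the ForeachBatchFunction wrapper imported in streaming.py. Its code is not shown in this diff, so the following is only a sketch of how a Py4J callback class of this shape is typically written; the interface name comes from the Scala file above, and the rest is assumed:

class ForeachBatchFunction(object):
    """Py4J proxy object that forwards JVM callbacks to the user's Python function (sketch)."""

    def __init__(self, sql_ctx, func):
        self.sql_ctx = sql_ctx
        self.func = func

    def call(self, jdf, batch_id):
        from pyspark.sql.dataframe import DataFrame
        try:
            # Wrap the Java DataFrame handed over by the JVM and invoke the user function.
            self.func(DataFrame(jdf, self.sql_ctx), batch_id)
        except Exception as e:
            self.error = e
            raise e

    class Java:
        # Tells Py4J which JVM interface this Python object implements.
        implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction']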
@@ -21,14 +21,15 @@ import java.util.Locale

 import scala.collection.JavaConverters._

-import org.apache.spark.annotation.InterfaceStability
-import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter}
+import org.apache.spark.annotation.{InterfaceStability, Since}
+import org.apache.spark.api.java.function.VoidFunction2
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
-import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2}
+import org.apache.spark.sql.execution.streaming.sources._
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}

 /**
@@ -279,6 +280,21 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
         outputMode,
         useTempCheckpointLocation = true,
         trigger = trigger)
+    } else if (source == "foreachBatch") {
+      assertNotPartitioned("foreachBatch")
+      if (trigger.isInstanceOf[ContinuousTrigger]) {
+        throw new AnalysisException("'foreachBatch' is not supported with continuous trigger")
+      }
+      val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc)
+      df.sparkSession.sessionState.streamingQueryManager.startQuery(
+        extraOptions.get("queryName"),
+        extraOptions.get("checkpointLocation"),
+        df,
+        extraOptions.toMap,
+        sink,
+        outputMode,
+        useTempCheckpointLocation = true,
+        trigger = trigger)
     } else {
       val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
       val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
@@ -322,6 +338,45 @@
     this
   }

+  /**
+   * :: Experimental ::
+   *
+   * (Scala-specific) Sets the output of the streaming query to be processed using the provided
+   * function. This is supported only in the micro-batch execution modes (that is, when the
+   * trigger is not continuous). In every micro-batch, the provided function will be called
+   * with (i) the output rows as a Dataset and (ii) the batch identifier. The batchId can be
+   * used to deduplicate and transactionally write the output (that is, the provided Dataset)
+   * to external systems. The output Dataset is guaranteed to be exactly the same for the same
+   * batchId (assuming all operations are deterministic in the query).
+   *
+   * @since 2.4.0
+   */
+  @InterfaceStability.Evolving
+  def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
Review comment: it's unclear that only one of foreach and foreachBatch can take effect at a time.
Reply: good point.
Reply: Well... that is an existing problem, because one can already write confusing code that sets a format and then also a foreach writer; this will execute the foreach but it looks confusing nonetheless. In fact you can do the same by setting both foreach and foreachBatch. This is a general existing problem that should be addressed in a different PR. (See the illustrative sketch after this method.)
+    this.source = "foreachBatch"
+    if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null")
+    this.foreachBatchWriter = function
+    this
+  }
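As referenced in the thread above, a hypothetical PySpark snippet (not from the PR, and assuming an existing SparkSession named spark) illustrating the ambiguity the reviewers describe: because foreachBatch simply reassigns the writer's source, whichever sink-setting call comes last wins and the earlier one is silently ignored.

# Hypothetical example of the ambiguous chaining discussed above; the 'rate'
# source is used only to have a streaming DataFrame at hand.
df = spark.readStream.format("rate").load()

def handle_batch(batch_df, batch_id):
    print("batch", batch_id, batch_df.count())

query = (df.writeStream
           .format("console")            # overwritten by the next call, silently ignored
           .foreachBatch(handle_batch)   # reassigns source to "foreachBatch", so this wins
           .start())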
+
+  /**
+   * :: Experimental ::
+   *
+   * (Java-specific) Sets the output of the streaming query to be processed using the provided
+   * function. This is supported only in the micro-batch execution modes (that is, when the
+   * trigger is not continuous). In every micro-batch, the provided function will be called
+   * with (i) the output rows as a Dataset and (ii) the batch identifier. The batchId can be
+   * used to deduplicate and transactionally write the output (that is, the provided Dataset)
+   * to external systems. The output Dataset is guaranteed to be exactly the same for the same
+   * batchId (assuming all operations are deterministic in the query).
+   *
+   * @since 2.4.0
+   */
+  @InterfaceStability.Evolving
+  def foreachBatch(function: VoidFunction2[Dataset[T], Long]): DataStreamWriter[T] = {
+    foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId))
+  }
+
   private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols =>
     cols.map(normalize(_, "Partition"))
   }
@@ -358,5 +413,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

   private var foreachWriter: ForeachWriter[T] = null

+  private var foreachBatchWriter: (Dataset[T], Long) => Unit = null
+
   private var partitioningColumns: Option[Seq[String]] = None
 }
Review comment: This was copied verbatim from python streaming/context.py.