[SPARK-24565][SS] Add API for in Structured Streaming for exposing output rows of each microbatch as a DataFrame #21571

Closed
wants to merge 10 commits
25 changes: 24 additions & 1 deletion python/pyspark/java_gateway.py
@@ -31,7 +31,7 @@
if sys.version >= '3':
xrange = range

from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters
from pyspark.find_spark_home import _find_spark_home
from pyspark.serializers import read_int, write_with_length, UTF8Deserializer

@@ -145,3 +145,26 @@ def do_server_auth(conn, auth_secret):
if reply != "ok":
conn.close()
raise Exception("Unexpected reply from iterator server.")


def ensure_callback_server_started(gw):
Contributor Author:
This was copied verbatim from python/pyspark/streaming/context.py

"""
Start callback server if not already started. The callback server is needed if the Java
driver process needs to call back into the Python driver process to execute Python code.
"""

# getattr will fallback to JVM, so we cannot test by hasattr()
if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
gw.callback_server_parameters.eager_load = True
gw.callback_server_parameters.daemonize = True
gw.callback_server_parameters.daemonize_connections = True
gw.callback_server_parameters.port = 0
gw.start_callback_server(gw.callback_server_parameters)
cbport = gw._callback_server.server_socket.getsockname()[1]
gw._callback_server.port = cbport
# gateway with real port
gw._python_proxy_port = gw._callback_server.port
# get the GatewayServer object in JVM by ID
jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
# update the port of CallbackClient with real port
jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)
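
For context, a minimal usage sketch of this helper (the EchoFunction class and its Java interface name are hypothetical; only ensure_callback_server_started itself comes from this diff). Any PySpark code that registers a Py4J callback object should call the helper first so the JVM has a running callback server to call into:

from pyspark.sql import SparkSession
from pyspark.java_gateway import ensure_callback_server_started

spark = SparkSession.builder.master("local[2]").getOrCreate()
gw = spark.sparkContext._gateway

# Idempotent: starts the callback server on a free port once, then registers
# the real port with the JVM-side GatewayServer.
ensure_callback_server_started(gw)

class EchoFunction(object):
    # The JVM can invoke call() on this object through the callback server.
    def call(self, message):
        return message

    class Java:
        implements = ["com.example.EchoFunction"]  # hypothetical Java interface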
33 changes: 32 additions & 1 deletion python/pyspark/sql/streaming.py
@@ -24,12 +24,14 @@
else:
intlike = (int, long)

from py4j.java_gateway import java_import

from pyspark import since, keyword_only
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import _to_seq
from pyspark.sql.readwriter import OptionUtils, to_str
from pyspark.sql.types import *
from pyspark.sql.utils import StreamingQueryException
from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException

__all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"]

@@ -1016,6 +1018,35 @@ def func_with_open_process_close(partition_id, iterator):
self._jwrite.foreach(jForeachWriter)
return self

@since(2.4)
def foreachBatch(self, func):
"""
Sets the output of the streaming query to be processed using the provided
function. This is supported only in the micro-batch execution modes (that is, when the
trigger is not continuous). In every micro-batch, the provided function will be called
with (i) the output rows as a DataFrame and (ii) the batch identifier.
The batchId can be used to deduplicate and transactionally write the output
(that is, the provided DataFrame) to external systems. The output DataFrame is guaranteed
to be exactly the same for the same batchId (assuming all operations are deterministic
in the query).

.. note:: Evolving.

>>> def func(batch_df, batch_id):
... batch_df.collect()
...
>>> writer = sdf.writeStream.foreachBatch(func)
"""

from pyspark.java_gateway import ensure_callback_server_started
gw = self._spark._sc._gateway
java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*")

wrapped_func = ForeachBatchFunction(self._spark, func)
gw.jvm.PythonForeachBatchHelper.callForeachBatch(self._jwrite, wrapped_func)
ensure_callback_server_started(gw)
Member:
This should be above; otherwise there is a race where the streaming query calls this Python function before the callback server is started.

Contributor Author:
This is not possible, because the callback from the JVM ForeachBatch sink to Python is made ONLY after the query is started, and the query cannot be started until this foreachBatch() method returns and start() is called.

Member:
I am sorry if I'm mistaken, but can't we still put this above? It looks odd to ensure the callback server only at the end.
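
For reference, the ordering being suggested here would look like this (the same calls as in the diff above, with ensure_callback_server_started moved up; a sketch only, not part of the diff):

from pyspark.java_gateway import ensure_callback_server_started
gw = self._spark._sc._gateway
java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*")

# Start the callback server before the JVM receives a reference to the
# Python function, so no callback can race with server startup.
ensure_callback_server_started(gw)
wrapped_func = ForeachBatchFunction(self._spark, func)
gw.jvm.PythonForeachBatchHelper.callForeachBatch(self._jwrite, wrapped_func)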

return self
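
A short usage sketch of the new API (the "rate" source and the output path are placeholders, and spark is assumed to be an active SparkSession). The batchId lets the function write each micro-batch idempotently to sinks that have no native streaming support:

from pyspark.sql.functions import lit

def write_batch(batch_df, batch_id):
    # Tag rows with the batch id; for a deterministic query the same batchId
    # always carries the same rows, so retries can be deduplicated downstream.
    batch_df.withColumn("batch_id", lit(batch_id)) \
        .write.mode("append").parquet("/tmp/foreach_batch_output")

query = spark.readStream.format("rate").load() \
    .writeStream.foreachBatch(write_batch).start()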

@ignore_unicode_prefix
@since(2.0)
def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None,
36 changes: 36 additions & 0 deletions python/pyspark/sql/tests.py
@@ -2126,6 +2126,42 @@ class WriterWithNonCallableClose(WithProcess):
tester.assert_invalid_writer(WriterWithNonCallableClose(),
"'close' in provided object is not callable")

def test_streaming_foreachBatch(self):
q = None
collected = dict()

def collectBatch(batch_df, batch_id):
Member:
collectBatch -> collect_batch per PEP 8.

collected[batch_id] = batch_df.collect()

try:
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
q = df.writeStream.foreachBatch(collectBatch).start()
q.processAllAvailable()
self.assertTrue(0 in collected)
self.assertEqual(len(collected[0]), 2)
finally:
if q:
q.stop()

def test_streaming_foreachBatch_propagates_python_errors(self):
from pyspark.sql.utils import StreamingQueryException

q = None

def collectBatch(df, id):
raise Exception("this should fail the query")

try:
df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
q = df.writeStream.foreachBatch(collectBatch).start()
q.processAllAvailable()
self.fail("Expected a failure")
except StreamingQueryException as e:
self.assertTrue("this should fail" in str(e))
finally:
if q:
q.stop()
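
A user-facing sketch of the behaviour the second test asserts (the "rate" source is a placeholder stream): an exception raised inside the foreachBatch function terminates the query and resurfaces on the driver as a StreamingQueryException.

from pyspark.sql import SparkSession
from pyspark.sql.utils import StreamingQueryException

spark = SparkSession.builder.master("local[2]").getOrCreate()

def failing_batch(batch_df, batch_id):
    raise Exception("this should fail the query")

query = spark.readStream.format("rate").load() \
    .writeStream.foreachBatch(failing_batch).start()
try:
    query.awaitTermination()
except StreamingQueryException as e:
    # The Python exception message is propagated through the JVM back here.
    print("query failed:", e)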

def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
23 changes: 23 additions & 0 deletions python/pyspark/sql/utils.py
@@ -150,3 +150,26 @@ def require_minimum_pyarrow_version():
if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
raise ImportError("PyArrow >= %s must be installed; however, "
"your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))


class ForeachBatchFunction(object):
"""
This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps
the user-defined 'foreachBatch' function such that it can be called from the JVM when
the query is active.
"""

def __init__(self, sql_ctx, func):
self.sql_ctx = sql_ctx
self.func = func

def call(self, jdf, batch_id):
from pyspark.sql.dataframe import DataFrame
try:
self.func(DataFrame(jdf, self.sql_ctx), batch_id)
except Exception as e:
self.error = e
raise e

class Java:
implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction']
18 changes: 2 additions & 16 deletions python/pyspark/streaming/context.py
@@ -79,22 +79,8 @@ def _ensure_initialized(cls):
java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")

# start callback server
# getattr will fallback to JVM, so we cannot test by hasattr()
if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
gw.callback_server_parameters.eager_load = True
gw.callback_server_parameters.daemonize = True
gw.callback_server_parameters.daemonize_connections = True
gw.callback_server_parameters.port = 0
gw.start_callback_server(gw.callback_server_parameters)
cbport = gw._callback_server.server_socket.getsockname()[1]
gw._callback_server.port = cbport
# gateway with real port
gw._python_proxy_port = gw._callback_server.port
# get the GatewayServer object in JVM by ID
jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
Member:
Nit: we could remove this import in this file though.

# update the port of CallbackClient with real port
jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)
from pyspark.java_gateway import ensure_callback_server_started
ensure_callback_server_started(gw)

# register serializer for TransformFunction
# it happens before creating SparkContext when loading from checkpointing
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala (new file)
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.api.python.PythonException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.streaming.DataStreamWriter

class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: ExpressionEncoder[T])
extends Sink {

override def addBatch(batchId: Long, data: DataFrame): Unit = {
val resolvedEncoder = encoder.resolveAndBind(
data.logicalPlan.output,
data.sparkSession.sessionState.analyzer)
val rdd = data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag)
val ds = data.sparkSession.createDataset(rdd)(encoder)
batchWriter(ds, batchId)
}

override def toString(): String = "ForeachBatchSink"
}


/**
* Interface that is meant to be extended by Python classes via Py4J.
* Py4J allows Python classes to implement Java interfaces so that the JVM can call back
* Python objects. In this case, this allows the user-defined Python `foreachBatch` function
* to be called from JVM when the query is active.
*/
trait PythonForeachBatchFunction {
/** Call the Python implementation of this function */
def call(batchDF: DataFrame, batchId: Long): Unit
}

object PythonForeachBatchHelper {
def callForeachBatch(dsw: DataStreamWriter[Row], pythonFunc: PythonForeachBatchFunction): Unit = {
dsw.foreachBatch(pythonFunc.call _)
}
}

sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -21,14 +21,15 @@ import java.util.Locale

import scala.collection.JavaConverters._

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter}
import org.apache.spark.annotation.{InterfaceStability, Since}
import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2}
import org.apache.spark.sql.execution.streaming.sources._
import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}

/**
@@ -279,6 +280,21 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
outputMode,
useTempCheckpointLocation = true,
trigger = trigger)
} else if (source == "foreachBatch") {
assertNotPartitioned("foreachBatch")
if (trigger.isInstanceOf[ContinuousTrigger]) {
throw new AnalysisException("'foreachBatch' is not supported with continuous trigger")
}
val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc)
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
df,
extraOptions.toMap,
sink,
outputMode,
useTempCheckpointLocation = true,
trigger = trigger)
} else {
val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
@@ -322,6 +338,45 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
this
}

/**
* :: Experimental ::
*
* (Scala-specific) Sets the output of the streaming query to be processed using the provided
* function. This is supported only in the micro-batch execution modes (that is, when the
* trigger is not continuous). In every micro-batch, the provided function will be called
* with (i) the output rows as a Dataset and (ii) the batch identifier.
* The batchId can be used to deduplicate and transactionally write the output
* (that is, the provided Dataset) to external systems. The output Dataset is guaranteed
* to be exactly the same for the same batchId (assuming all operations are deterministic in
* the query).
*
* @since 2.4.0
*/
@InterfaceStability.Evolving
def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
Member:
It's unclear that only one of foreachBatch and foreach can be set. Reading the doc, the user may think both can be set. Maybe we should disallow this case?

Contributor Author:
Good point.

Contributor Author (@tdas, Jun 17, 2018):
Well... that is an existing problem, because one can write the following confusing code:

df.writeStream.format("kafka").foreach(...).start()

This will execute the foreach, but it looks confusing nonetheless. In fact, you can also do:

df.writeStream.format("kafka").format("bla").format("random")....

This is a general existing problem that should be addressed in a different PR.

this.source = "foreachBatch"
if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null")
this.foreachBatchWriter = function
this
}

/**
* :: Experimental ::
*
* (Java-specific) Sets the output of the streaming query to be processed using the provided
* function. This is supported only in the micro-batch execution modes (that is, when the
* trigger is not continuous). In every micro-batch, the provided function will be called
* with (i) the output rows as a Dataset and (ii) the batch identifier.
* The batchId can be used to deduplicate and transactionally write the output
* (that is, the provided Dataset) to external systems. The output Dataset is guaranteed
* to be exactly the same for the same batchId (assuming all operations are deterministic in
* the query).
*
* @since 2.4.0
*/
@InterfaceStability.Evolving
def foreachBatch(function: VoidFunction2[Dataset[T], Long]): DataStreamWriter[T] = {
foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId))
}

private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols =>
cols.map(normalize(_, "Partition"))
}
@@ -358,5 +413,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

private var foreachWriter: ForeachWriter[T] = null

private var foreachBatchWriter: (Dataset[T], Long) => Unit = null

private var partitioningColumns: Option[Seq[String]] = None
}