From 043ab9d5a6c1aa47822f86037d8de20573de945e Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 11 Apr 2016 20:11:03 -0700 Subject: [PATCH 01/24] added python API for streaming dataframes --- python/pyspark/sql/dataframe.py | 12 ++ python/pyspark/sql/readwriter.py | 124 ++++++++++++++++++ python/pyspark/sql/streaming.py | 122 +++++++++++++++++ .../test_support/sql/streaming/text-test.txt | 2 + .../apache/spark/sql/ContinuousQuery.scala | 4 +- .../apache/spark/sql/DataFrameWriter.scala | 8 +- .../scala/org/apache/spark/sql/Trigger.scala | 30 ++--- 7 files changed, 281 insertions(+), 21 deletions(-) create mode 100644 python/pyspark/sql/streaming.py create mode 100644 python/test_support/sql/streaming/text-test.txt diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d473d6b534647..ce77c9a29f7e8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -197,6 +197,18 @@ def isLocal(self): """ return self._jdf.isLocal() + @property + @since(2.0) + def isStreaming(self): + """Returns true if this :class:`Dataset` contains one or more sources that continuously + return data as it arrives. A :class:`Dataset` that reads data from a streaming source + must be executed as a :class:`ContinuousQuery` using the :func:`startStream()` method in + :class:`DataFrameWriter`. Methods that return a single answer, (e.g., :func:`count` or + :func:`collect`) will throw an [[AnalysisException]] when there is a streaming + source present. + """ + return self._jdf.isStreaming() + @since(1.3) def show(self, n=20, truncate=True): """Prints the first ``n`` rows to the console. diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 0cef37e57cd54..06e382293d14b 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -136,6 +136,32 @@ def load(self, path=None, format=None, schema=None, **options): else: return self._df(self._jreader.load()) + @since(2.0) + def stream(self, path=None, format=None, schema=None, **options): + """Loads a data stream from a data source and returns it as a :class`DataFrame`. + + :param path: optional string for file-system backed data sources. + :param format: optional string for format of the data source. Default to 'parquet'. + :param schema: optional :class:`StructType` for the input schema. + :param options: all other string options + + >>> df = sqlContext.read.format('text').stream('python/test_support/sql/streaming') + >>> df.isStreaming + True + """ + if format is not None: + self.format(format) + if schema is not None: + self.schema(schema) + self.options(**options) + if path is not None: + if type(path) != str or len(path.strip()) == 0: + raise ValueError("If the path is provided for stream, " +\ + "it needs to be a non-empty string. List of paths are not supported.") + return self._df(self._jreader.stream(path)) + else: + return self._df(self._jreader.stream()) + @since(1.4) def json(self, path, schema=None): """ @@ -334,6 +360,10 @@ def __init__(self, df): self._sqlContext = df.sql_ctx self._jwrite = df._jdf.write() + def _cq(self, jcq): + from pyspark.sql.streaming import ContinuousQuery + return ContinuousQuery(jcq, self._sqlContext) + @since(1.4) def mode(self, saveMode): """Specifies the behavior when data or table already exists. 
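# A minimal reader-side sketch of the API added in the hunks above, assuming an
# existing SQLContext bound to `sqlContext`; the input directory '/tmp/stream-in'
# is hypothetical. It mirrors the stream() doctest rather than defining behaviour.
from pyspark.sql.types import StructType, StructField, StringType

schema = StructType([StructField("value", StringType(), True)])
streaming_df = sqlContext.read.format('text').schema(schema).stream('/tmp/stream-in')

# A streaming Dataset must be started via DataFrameWriter.startStream(); one-shot
# actions such as count() or collect() raise AnalysisException on it.
assert streaming_df.isStreaming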
@@ -395,6 +425,37 @@ def partitionBy(self, *cols): self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) return self + @since(2.0) + def queryName(self, queryName): + """Specifies the name of the :class:`ContinuousQuery` that can be started with + :func:`startStream()`. This name must be unique among all the currently active queries + in the associated SQLContext. + + :param queryName: unique name for the query + + >>> sdf.write.queryName('streaming_query') + """ + if not queryName or type(queryName) != str or len(queryName.strip()) == 0: + raise ValueError('The queryName must be a non-empty string. Got: %s' % queryName) + self._jwrite = self._jwrite.queryName(queryName) + return self + + @since(2.0) + def trigger(self, trigger): + """Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it + will run the query as fast as possible. + + :param trigger: a :class:`Trigger`. + >>> from pyspark.sql.streaming import ProcessingTime + >>> # trigger the query for execution every 5 seconds + >>> sdf.trigger(ProcessingTime('5 seconds')) + """ + from pyspark.sql.streaming import Trigger + if not trigger or issubclass(trigger, Trigger): + raise ValueError('The trigger must be of the Trigger class. Got: %s' % trigger) + self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext)) + return self + @since(1.4) def save(self, path=None, format=None, mode=None, partitionBy=None, **options): """Saves the contents of the :class:`DataFrame` to a data source. @@ -426,6 +487,67 @@ def save(self, path=None, format=None, mode=None, partitionBy=None, **options): else: self._jwrite.save(path) + @ignore_unicode_prefix + @since(2.0) + def startStream(self, path=None, format=None, mode=None, partitionBy=None, + queryName=None, checkpointLocation=None, trigger=None, **options): + """Saves the contents of the :class:`DataFrame` to a data source. + + The data source is specified by the ``format`` and a set of ``options``. + If ``format`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + :param path: the path in a Hadoop supported file system + :param format: the format used to save + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns + :param queryName: unique name for the query + :param trigger: Set the trigger for the stream query. The default value is + `ProcessingTime(0)` and it will run as fast as possible. + :param options: all other string options + + >>> temp = tempfile.mkdtemp() + >>> cq = sdf.write.format('text').startStream(os.path.join(temp, 'out'), + ... checkpointLocation=os.path.join(temp, 'chk')) + >>> cq.isActive + True + >>> cq.stop() + >>> cq.isActive + False + >>> from pyspark.sql.streaming import ProcessingTime + >>> cq = sdf.write.startStream(os.path.join(temp, 'out'), format='text', + ... queryName='my_query', trigger=ProcessingTime('5 seconds'), + ... 
checkpointLocation=os.path.join(temp, 'chk')) + >>> cq.name + 'my_query' + >>> cq.isActive + True + >>> cq.stop() + """ + self.mode(mode).options(**options) + if partitionBy is not None: + self.partitionBy(partitionBy) + if format is not None: + self.format(format) + if queryName is not None: + self.queryName(queryName) + if checkpointLocation is not None: + if type(checkpointLocation) != str or len(checkpointLocation.strip()) == 0: + raise ValueError('The checkpointLocation must be a non-empty string. Got: %s' % + checkpointLocation) + self.option('checkpointLocation', checkpointLocation) + if trigger is not None: + self.trigger(trigger) + if path is None: + return self._cq(self._jwrite.startStream()) + else: + return self._cq(self._jwrite.startStream(path)) + @since(1.4) def insertInto(self, tableName, overwrite=False): """Inserts the content of the :class:`DataFrame` to the specified table. @@ -625,6 +747,8 @@ def _test(): globs['sqlContext'] = SQLContext(sc) globs['hiveContext'] = HiveContext(sc) globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned') + globs['sdf'] =\ + globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming') (failure_count, test_count) = doctest.testmod( pyspark.sql.readwriter, globs=globs, diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py new file mode 100644 index 0000000000000..c8a0fd3372cd3 --- /dev/null +++ b/python/pyspark/sql/streaming.py @@ -0,0 +1,122 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta, abstractmethod + +from pyspark import since + +__all__ = ["ContinuousQuery", "ProcessingTime"] + + +class ContinuousQuery(object): + """ + A handle to a query that is executing continuously in the background as new data arrives. + All these methods are thread-safe. + + .. note:: Experimental + + .. versionadded:: 2.0 + """ + + def __init__(self, jcq, sqlContext): + self._jcq = jcq + self._sqlContext = sqlContext + + @property + @since(2.0) + def name(self): + """The name of the continuous query. + """ + return self._jcq.name() + + @property + @since(2.0) + def isActive(self): + """Whether this continuous query is currently active or not. + """ + return self._jcq.isActive() + + @since(2.0) + def awaitTermination(self): + """Waits for the termination of `this` query, either by :func:`query.stop()` or by an + exception. If the query has terminated with an exception, then the exception will be thrown. + + If the query has terminated, then all subsequent calls to this method will either return + immediately (if the query was terminated by :func:`stop()`), or throw the exception + immediately (if the query has terminated with exception). 
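# A lifecycle sketch for the ContinuousQuery handle documented here, assuming a
# streaming DataFrame `sdf` as in the startStream() doctest above; the output and
# checkpoint directories are hypothetical temporary paths, not part of the patch.
import os
import tempfile

temp = tempfile.mkdtemp()
cq = sdf.write.format('text') \
    .option('checkpointLocation', os.path.join(temp, 'chk')) \
    .queryName('sketch_query') \
    .startStream(os.path.join(temp, 'out'))

assert cq.name == 'sketch_query'
assert cq.isActive        # the query keeps running in the background
cq.stop()                 # terminate it; awaitTermination() would now return
assert not cq.isActive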
+ + throws ContinuousQueryException, if `this` query has terminated with an exception + """ + self._jcq.awaitTermination() + + @since(2.0) + def awaitTermination(self, timeoutMs): + """Waits for the termination of `this` query, either by :func:`query.stop()` or by an + exception. If the query has terminated with an exception, then the exception will be thrown. + Otherwise, it returns whether the query has terminated or not within the `timeoutMs` + milliseconds. + + If the query has terminated, then all subsequent calls to this method will either return + `true` immediately (if the query was terminated by :func:`stop()`), or throw the exception + immediately (if the query has terminated with exception). + + throws ContinuousQueryException, if `this` query has terminated with an exception + """ + if type(timeoutMs) != int or timeoutMs < 0: + raise ValueError("timeoutMs must be a positive integer. Got %s" % timeoutMs) + return self._jcq.awaitTermination(timeoutMs) + + + @since(2.0) + def stop(self): + """Stop this continuous query. + """ + self._jcq.stop() + + +class Trigger: + """Used to indicate how often results should be produced by a :class:`ContinuousQuery`. + + .. note:: Experimental + + .. versionadded:: 2.0 + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def _to_java_trigger(self, sqlContext): pass + + +class ProcessingTime(Trigger): + """A trigger that runs a query periodically based on the processing time. If `interval` is 0, + the query will run as fast as possible. + + The interval should be given as a string, e.g. '2 seconds', '5 minutes', ... + + .. note:: Experimental + + .. versionadded:: 2.0 + """ + + def __init__(self, interval): + if interval is None or type(interval) != str or len(interval.strip()) == 0: + raise ValueError("interval should be a non empty interval string, e.g. '2 seconds'.") + self.interval = interval + + def _to_java_trigger(self, sqlContext): + return sqlContext._sc._jvm.org.apache.spark.sql.ProcessingTime.create(self.interval) \ No newline at end of file diff --git a/python/test_support/sql/streaming/text-test.txt b/python/test_support/sql/streaming/text-test.txt new file mode 100644 index 0000000000000..ae1e76c9e93a7 --- /dev/null +++ b/python/test_support/sql/streaming/text-test.txt @@ -0,0 +1,2 @@ +hello +this \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala index d9973b092dc11..953169b63604f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala @@ -56,7 +56,7 @@ trait ContinuousQuery { * Returns current status of all the sources. * @since 2.0.0 */ - def sourceStatuses: Array[SourceStatus] + def sourceStatuses: Array[SourceStatus] /** Returns current status of the sink. */ def sinkStatus: SinkStatus @@ -77,7 +77,7 @@ trait ContinuousQuery { /** * Waits for the termination of `this` query, either by `query.stop()` or by an exception. - * If the query has terminated with an exception, then the exception will be throw. + * If the query has terminated with an exception, then the exception will be thrown. * Otherwise, it returns whether the query has terminated or not within the `timeoutMs` * milliseconds. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 54d250867fbb3..e6836cf90f710 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -86,18 +86,18 @@ final class DataFrameWriter private[sql](df: DataFrame) { * * Scala Example: * {{{ - * def.writer.trigger(ProcessingTime("10 seconds")) + * df.write.trigger(ProcessingTime("10 seconds")) * * import scala.concurrent.duration._ - * def.writer.trigger(ProcessingTime(10.seconds)) + * df.write.trigger(ProcessingTime(10.seconds)) * }}} * * Java Example: * {{{ - * def.writer.trigger(ProcessingTime.create("10 seconds")) + * df.write.trigger(ProcessingTime.create("10 seconds")) * * import java.util.concurrent.TimeUnit - * def.writer.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala index c4e54b3f90ac5..828754ff3cdda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala @@ -35,28 +35,28 @@ sealed trait Trigger {} /** * :: Experimental :: - * A trigger that runs a query periodically based on the processing time. If `intervalMs` is 0, + * A trigger that runs a query periodically based on the processing time. If `interval` is 0, * the query will run as fast as possible. * * Scala Example: * {{{ - * def.writer.trigger(ProcessingTime("10 seconds")) + * df.write.trigger(ProcessingTime("10 seconds")) * * import scala.concurrent.duration._ - * def.writer.trigger(ProcessingTime(10.seconds)) + * df.write.trigger(ProcessingTime(10.seconds)) * }}} * * Java Example: * {{{ - * def.writer.trigger(ProcessingTime.create("10 seconds")) + * df.write.trigger(ProcessingTime.create("10 seconds")) * * import java.util.concurrent.TimeUnit - * def.writer.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} */ @Experimental -case class ProcessingTime(intervalMs: Long) extends Trigger { - require(intervalMs >= 0, "the interval of trigger should not be negative") +case class ProcessingTime(interval: Long) extends Trigger { + require(interval >= 0, "the interval of trigger should not be negative") } /** @@ -67,11 +67,11 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { object ProcessingTime { /** - * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible. * * Example: * {{{ - * def.writer.trigger(ProcessingTime("10 seconds")) + * df.write.trigger(ProcessingTime("10 seconds")) * }}} */ def apply(interval: String): ProcessingTime = { @@ -94,12 +94,12 @@ object ProcessingTime { } /** - * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible. * * Example: * {{{ * import scala.concurrent.duration._ - * def.writer.trigger(ProcessingTime(10.seconds)) + * df.write.trigger(ProcessingTime(10.seconds)) * }}} */ def apply(interval: Duration): ProcessingTime = { @@ -107,11 +107,11 @@ object ProcessingTime { } /** - * Create a [[ProcessingTime]]. 
If `intervalMs` is 0, the query will run as fast as possible. + * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible. * * Example: * {{{ - * def.writer.trigger(ProcessingTime.create("10 seconds")) + * df.write.trigger(ProcessingTime.create("10 seconds")) * }}} */ def create(interval: String): ProcessingTime = { @@ -119,12 +119,12 @@ object ProcessingTime { } /** - * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible. * * Example: * {{{ * import java.util.concurrent.TimeUnit - * def.writer.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} */ def create(interval: Long, unit: TimeUnit): ProcessingTime = { From ce4171bfe75232efe39e08019aaf985e4fc2dfae Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 11 Apr 2016 20:29:33 -0700 Subject: [PATCH 02/24] minor --- python/pyspark/sql/dataframe.py | 4 ++-- python/pyspark/sql/readwriter.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ce77c9a29f7e8..a1ad7d573e47a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -202,9 +202,9 @@ def isLocal(self): def isStreaming(self): """Returns true if this :class:`Dataset` contains one or more sources that continuously return data as it arrives. A :class:`Dataset` that reads data from a streaming source - must be executed as a :class:`ContinuousQuery` using the :func:`startStream()` method in + must be executed as a :class:`ContinuousQuery` using the :func:`startStream` method in :class:`DataFrameWriter`. Methods that return a single answer, (e.g., :func:`count` or - :func:`collect`) will throw an [[AnalysisException]] when there is a streaming + :func:`collect`) will throw an :class:`AnalysisException` when there is a streaming source present. """ return self._jdf.isStreaming() diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 06e382293d14b..d0150be462750 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -428,7 +428,7 @@ def partitionBy(self, *cols): @since(2.0) def queryName(self, queryName): """Specifies the name of the :class:`ContinuousQuery` that can be started with - :func:`startStream()`. This name must be unique among all the currently active queries + :func:`startStream`. This name must be unique among all the currently active queries in the associated SQLContext. :param queryName: unique name for the query @@ -446,6 +446,7 @@ def trigger(self, trigger): will run the query as fast as possible. :param trigger: a :class:`Trigger`. + >>> from pyspark.sql.streaming import ProcessingTime >>> # trigger the query for execution every 5 seconds >>> sdf.trigger(ProcessingTime('5 seconds')) @@ -509,6 +510,7 @@ def startStream(self, path=None, format=None, mode=None, partitionBy=None, :param queryName: unique name for the query :param trigger: Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will run as fast as possible. + :param checkpointLocation: An optional location for checkpointing state and metadata. 
:param options: all other string options >>> temp = tempfile.mkdtemp() From 6ae7fd1864a14b1403c770f1b06d585a16f2d8d3 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 11 Apr 2016 20:36:22 -0700 Subject: [PATCH 03/24] fix pystyle --- python/pyspark/sql/readwriter.py | 4 ++-- python/pyspark/sql/streaming.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index d0150be462750..435d5fc51c523 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -156,8 +156,8 @@ def stream(self, path=None, format=None, schema=None, **options): self.options(**options) if path is not None: if type(path) != str or len(path.strip()) == 0: - raise ValueError("If the path is provided for stream, " +\ - "it needs to be a non-empty string. List of paths are not supported.") + raise ValueError("If the path is provided for stream, it needs to be a " + + "non-empty string. List of paths are not supported.") return self._df(self._jreader.stream(path)) else: return self._df(self._jreader.stream()) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index c8a0fd3372cd3..ffecdeb2945b7 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -80,7 +80,6 @@ def awaitTermination(self, timeoutMs): raise ValueError("timeoutMs must be a positive integer. Got %s" % timeoutMs) return self._jcq.awaitTermination(timeoutMs) - @since(2.0) def stop(self): """Stop this continuous query. @@ -99,7 +98,10 @@ class Trigger: __metaclass__ = ABCMeta @abstractmethod - def _to_java_trigger(self, sqlContext): pass + def _to_java_trigger(self, sqlContext): + """Internal method to construct the trigger on the jvm. + """ + pass class ProcessingTime(Trigger): @@ -119,4 +121,4 @@ def __init__(self, interval): self.interval = interval def _to_java_trigger(self, sqlContext): - return sqlContext._sc._jvm.org.apache.spark.sql.ProcessingTime.create(self.interval) \ No newline at end of file + return sqlContext._sc._jvm.org.apache.spark.sql.ProcessingTime.create(self.interval) From da63975b9d26968474aefddbbee25c346d3cdd7b Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 11 Apr 2016 23:23:43 -0700 Subject: [PATCH 04/24] fix test --- sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala index 828754ff3cdda..256e8a47a4665 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala @@ -55,8 +55,8 @@ sealed trait Trigger {} * }}} */ @Experimental -case class ProcessingTime(interval: Long) extends Trigger { - require(interval >= 0, "the interval of trigger should not be negative") +case class ProcessingTime(intervalMs: Long) extends Trigger { + require(intervalMs >= 0, "the interval of trigger should not be negative") } /** From 96ac9f9f0172692bc093922a36906b834d1bf429 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 12 Apr 2016 10:14:57 -0700 Subject: [PATCH 05/24] fix py tests --- python/pyspark/sql/readwriter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 435d5fc51c523..ad102b10339d7 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -433,7 +433,7 @@ def 
queryName(self, queryName): :param queryName: unique name for the query - >>> sdf.write.queryName('streaming_query') + >>> writer = sdf.write.queryName('streaming_query') """ if not queryName or type(queryName) != str or len(queryName.strip()) == 0: raise ValueError('The queryName must be a non-empty string. Got: %s' % queryName) @@ -449,7 +449,7 @@ def trigger(self, trigger): >>> from pyspark.sql.streaming import ProcessingTime >>> # trigger the query for execution every 5 seconds - >>> sdf.trigger(ProcessingTime('5 seconds')) + >>> writer = sdf.write.trigger(ProcessingTime('5 seconds')) """ from pyspark.sql.streaming import Trigger if not trigger or issubclass(trigger, Trigger): @@ -526,7 +526,7 @@ def startStream(self, path=None, format=None, mode=None, partitionBy=None, ... queryName='my_query', trigger=ProcessingTime('5 seconds'), ... checkpointLocation=os.path.join(temp, 'chk')) >>> cq.name - 'my_query' + u'my_query' >>> cq.isActive True >>> cq.stop() From 1fe20edd78cee9851c2a973b543171a22021cdf9 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 12 Apr 2016 13:01:54 -0700 Subject: [PATCH 06/24] fix object --- python/pyspark/sql/streaming.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index ffecdeb2945b7..c87cb54bcd107 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -87,7 +87,7 @@ def stop(self): self._jcq.stop() -class Trigger: +class Trigger(object): """Used to indicate how often results should be produced by a :class:`ContinuousQuery`. .. note:: Experimental From 6dde6b85479f5992774b67b0baacf0da406ea8ab Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 14 Apr 2016 15:26:47 -0700 Subject: [PATCH 07/24] address comments --- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/streaming.py | 28 +++++++++------------------- 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index ad102b10339d7..2855eabd071f0 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -452,7 +452,7 @@ def trigger(self, trigger): >>> writer = sdf.write.trigger(ProcessingTime('5 seconds')) """ from pyspark.sql.streaming import Trigger - if not trigger or issubclass(trigger, Trigger): + if not trigger or issubclass(type(trigger), Trigger): raise ValueError('The trigger must be of the Trigger class. Got: %s' % trigger) self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext)) return self diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index c87cb54bcd107..d4e2370246736 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -51,9 +51,11 @@ def isActive(self): return self._jcq.isActive() @since(2.0) - def awaitTermination(self): + def awaitTermination(self, timeoutMs=None): """Waits for the termination of `this` query, either by :func:`query.stop()` or by an exception. If the query has terminated with an exception, then the exception will be thrown. + If `timeoutMs` is set, it returns whether the query has terminated or not within the + `timeoutMs` milliseconds. 
If the query has terminated, then all subsequent calls to this method will either return immediately (if the query was terminated by :func:`stop()`), or throw the exception @@ -61,24 +63,12 @@ def awaitTermination(self): throws ContinuousQueryException, if `this` query has terminated with an exception """ - self._jcq.awaitTermination() - - @since(2.0) - def awaitTermination(self, timeoutMs): - """Waits for the termination of `this` query, either by :func:`query.stop()` or by an - exception. If the query has terminated with an exception, then the exception will be thrown. - Otherwise, it returns whether the query has terminated or not within the `timeoutMs` - milliseconds. - - If the query has terminated, then all subsequent calls to this method will either return - `true` immediately (if the query was terminated by :func:`stop()`), or throw the exception - immediately (if the query has terminated with exception). - - throws ContinuousQueryException, if `this` query has terminated with an exception - """ - if type(timeoutMs) != int or timeoutMs < 0: - raise ValueError("timeoutMs must be a positive integer. Got %s" % timeoutMs) - return self._jcq.awaitTermination(timeoutMs) + if timeoutMs is not None: + if type(timeoutMs) != int or timeoutMs < 0: + raise ValueError("timeoutMs must be a positive integer. Got %s" % timeoutMs) + return self._jcq.awaitTermination(timeoutMs) + else: + return self._jcq.awaitTermination() @since(2.0) def stop(self): From b95d6edfb18fb77eb78c0a4fceb1859bf2448ea9 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 14 Apr 2016 15:34:24 -0700 Subject: [PATCH 08/24] minor --- python/pyspark/sql/readwriter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 2855eabd071f0..f965619db3eb9 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -442,8 +442,8 @@ def queryName(self, queryName): @since(2.0) def trigger(self, trigger): - """Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it - will run the query as fast as possible. + """Set the trigger for the stream query. If this is not set it will run the query as fast + as possible. :param trigger: a :class:`Trigger`. From 2e0a527acaae34ed9a0104894c1be9a309b5b227 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 14 Apr 2016 15:34:58 -0700 Subject: [PATCH 09/24] more --- python/pyspark/sql/readwriter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index f965619db3eb9..f83117ffead19 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -443,7 +443,7 @@ def queryName(self, queryName): @since(2.0) def trigger(self, trigger): """Set the trigger for the stream query. If this is not set it will run the query as fast - as possible. + as possible, which is equivalent to setting the trigger to ``ProcessingTime('0 seconds')``. :param trigger: a :class:`Trigger`. 
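# The commits that follow iterate on how trigger() should validate its argument.
# A standalone sketch of the distinction involved, using a stand-in Trigger
# hierarchy (illustrative only, not the classes defined in streaming.py):
from abc import ABCMeta, abstractmethod


class Trigger(object):
    __metaclass__ = ABCMeta  # Python 2 style metaclass, as in streaming.py

    @abstractmethod
    def _to_java_trigger(self, sqlContext):
        pass


class ProcessingTime(Trigger):
    def __init__(self, interval):
        self.interval = interval

    def _to_java_trigger(self, sqlContext):
        return None  # stub; the real class builds a JVM ProcessingTime


t = ProcessingTime('5 seconds')
# issubclass() expects a class, so it must be applied to type(t), never to the
# instance itself; isinstance(t, Trigger) is the more direct spelling.
assert issubclass(type(t), Trigger)
assert isinstance(t, Trigger)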
From c55e605cfd05447448b399d8e4359834ba4f53f5 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 14 Apr 2016 16:53:07 -0700 Subject: [PATCH 10/24] register subclass --- python/pyspark/sql/streaming.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index d4e2370246736..dfb5c06165e72 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -112,3 +112,6 @@ def __init__(self, interval): def _to_java_trigger(self, sqlContext): return sqlContext._sc._jvm.org.apache.spark.sql.ProcessingTime.create(self.interval) + + +Trigger.register(ProcessingTime) From 588ce1ff9c07c8db6b296e735a74e397b60983ea Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 14 Apr 2016 18:20:07 -0700 Subject: [PATCH 11/24] try this --- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/streaming.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index f83117ffead19..cc2b3e6671502 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -452,7 +452,7 @@ def trigger(self, trigger): >>> writer = sdf.write.trigger(ProcessingTime('5 seconds')) """ from pyspark.sql.streaming import Trigger - if not trigger or issubclass(type(trigger), Trigger): + if not trigger or Trigger._is_subclass(trigger): raise ValueError('The trigger must be of the Trigger class. Got: %s' % trigger) self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext)) return self diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index dfb5c06165e72..5dd94f2fed72e 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -93,6 +93,10 @@ def _to_java_trigger(self, sqlContext): """ pass + @staticmethod + def _is_subclass(instance): + return isinstance(instance, ProcessingTime) + class ProcessingTime(Trigger): """A trigger that runs a query periodically based on the processing time. If `interval` is 0, @@ -112,6 +116,3 @@ def __init__(self, interval): def _to_java_trigger(self, sqlContext): return sqlContext._sc._jvm.org.apache.spark.sql.ProcessingTime.create(self.interval) - - -Trigger.register(ProcessingTime) From 147e9f96d05220331cc931bbdd8d6192537e98f2 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 15 Apr 2016 09:17:35 -0700 Subject: [PATCH 12/24] fix check --- python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/streaming.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index cc2b3e6671502..6734c38c0ac70 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -452,7 +452,7 @@ def trigger(self, trigger): >>> writer = sdf.write.trigger(ProcessingTime('5 seconds')) """ from pyspark.sql.streaming import Trigger - if not trigger or Trigger._is_subclass(trigger): + if not trigger or not issubclass(type(trigger), Trigger): raise ValueError('The trigger must be of the Trigger class. 
Got: %s' % trigger) self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext)) return self diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 5dd94f2fed72e..d4e2370246736 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -93,10 +93,6 @@ def _to_java_trigger(self, sqlContext): """ pass - @staticmethod - def _is_subclass(instance): - return isinstance(instance, ProcessingTime) - class ProcessingTime(Trigger): """A trigger that runs a query periodically based on the processing time. If `interval` is 0, From 7e2aa431ea577fc6cd070efbbcf1a5ed8c31e004 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sun, 17 Apr 2016 12:38:18 -0700 Subject: [PATCH 13/24] save prog --- python/pyspark/sql/readwriter.py | 49 ++++++++++++++------------------ python/pyspark/sql/tests.py | 7 +++++ 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index cc2b3e6671502..9e973d725e817 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -440,20 +440,27 @@ def queryName(self, queryName): self._jwrite = self._jwrite.queryName(queryName) return self + @keyword_only @since(2.0) - def trigger(self, trigger): + def trigger(self, processingTime=None): """Set the trigger for the stream query. If this is not set it will run the query as fast as possible, which is equivalent to setting the trigger to ``ProcessingTime('0 seconds')``. - :param trigger: a :class:`Trigger`. + :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. >>> from pyspark.sql.streaming import ProcessingTime >>> # trigger the query for execution every 5 seconds - >>> writer = sdf.write.trigger(ProcessingTime('5 seconds')) - """ - from pyspark.sql.streaming import Trigger - if not trigger or Trigger._is_subclass(trigger): - raise ValueError('The trigger must be of the Trigger class. Got: %s' % trigger) + >>> writer = sdf.write.trigger(processingTime='5 seconds') + """ + from pyspark.sql.streaming import ProcessingTime + trigger = None + if processingTime is not None: + if type(processingTime) != str or len(processingTime.strip()) == 0: + raise ValueError('The processing time must be a non empty string. Got: %s' % + processingTime) + trigger = ProcessingTime(processingTime) + if trigger is None: + raise ValueError('A trigger was not provided. Supported triggers: processingTime.') self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext)) return self @@ -491,7 +498,7 @@ def save(self, path=None, format=None, mode=None, partitionBy=None, **options): @ignore_unicode_prefix @since(2.0) def startStream(self, path=None, format=None, mode=None, partitionBy=None, - queryName=None, checkpointLocation=None, trigger=None, **options): + queryName=None, **options): """Saves the contents of the :class:`DataFrame` to a data source. The data source is specified by the ``format`` and a set of ``options``. @@ -500,7 +507,6 @@ def startStream(self, path=None, format=None, mode=None, partitionBy=None, :param path: the path in a Hadoop supported file system :param format: the format used to save - :param mode: specifies the behavior of the save operation when data already exists. * ``append``: Append contents of this :class:`DataFrame` to existing data. * ``overwrite``: Overwrite existing data. 
@@ -508,43 +514,30 @@ def startStream(self, path=None, format=None, mode=None, partitionBy=None, * ``error`` (default case): Throw an exception if data already exists. :param partitionBy: names of partitioning columns :param queryName: unique name for the query - :param trigger: Set the trigger for the stream query. The default value is - `ProcessingTime(0)` and it will run as fast as possible. - :param checkpointLocation: An optional location for checkpointing state and metadata. - :param options: all other string options + :param options: All other string options. You may want to provide a `checkpointLocation` + for most streams, however it is not required for a `memory` stream. - >>> temp = tempfile.mkdtemp() - >>> cq = sdf.write.format('text').startStream(os.path.join(temp, 'out'), - ... checkpointLocation=os.path.join(temp, 'chk')) + >>> cq = sdf.write.format('memory').startStream() >>> cq.isActive True >>> cq.stop() >>> cq.isActive False - >>> from pyspark.sql.streaming import ProcessingTime - >>> cq = sdf.write.startStream(os.path.join(temp, 'out'), format='text', - ... queryName='my_query', trigger=ProcessingTime('5 seconds'), - ... checkpointLocation=os.path.join(temp, 'chk')) + >>> cq = sdf.write.trigger(processingTime='5 seconds').startStream( + ... queryName='my_query', format='memory') >>> cq.name u'my_query' >>> cq.isActive True >>> cq.stop() """ - self.mode(mode).options(**options) + self.options(**options) if partitionBy is not None: self.partitionBy(partitionBy) if format is not None: self.format(format) if queryName is not None: self.queryName(queryName) - if checkpointLocation is not None: - if type(checkpointLocation) != str or len(checkpointLocation.strip()) == 0: - raise ValueError('The checkpointLocation must be a non-empty string. Got: %s' % - checkpointLocation) - self.option('checkpointLocation', checkpointLocation) - if trigger is not None: - self.trigger(trigger) if path is None: return self._cq(self._jwrite.startStream()) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e4f79c911c0d9..8de51342b53cd 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -879,6 +879,13 @@ def test_save_and_load_builder(self): shutil.rmtree(tmpPath) + def test_stream_save_options(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + + shutil.rmtree(tmpPath) + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) From ed6dcb6770f08d21329170b66af47d016fadd074 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sun, 17 Apr 2016 19:13:36 -0700 Subject: [PATCH 14/24] address michael's comments --- python/pyspark/sql/readwriter.py | 5 ++--- python/pyspark/sql/tests.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 9e973d725e817..43f78d3600a34 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -497,9 +497,8 @@ def save(self, path=None, format=None, mode=None, partitionBy=None, **options): @ignore_unicode_prefix @since(2.0) - def startStream(self, path=None, format=None, mode=None, partitionBy=None, - queryName=None, **options): - """Saves the contents of the :class:`DataFrame` to a data source. + def startStream(self, path=None, format=None, partitionBy=None, queryName=None, **options): + """Streams the contents of the :class:`DataFrame` to a data source. 
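# By this point in the series trigger() takes its interval as a keyword argument
# and the doctests lean on the in-memory sink, which needs no checkpointLocation.
# A sketch of that shape, assuming a streaming DataFrame `sdf` as elsewhere here:
cq = sdf.write \
    .trigger(processingTime='10 seconds') \
    .queryName('memory_sketch') \
    .startStream(format='memory')

assert cq.isActive
cq.stop()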
The data source is specified by the ``format`` and a set of ``options``. If ``format`` is not specified, the default data source configured by diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8de51342b53cd..bdb494a7054df 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -880,10 +880,17 @@ def test_save_and_load_builder(self): shutil.rmtree(tmpPath) def test_stream_save_options(self): - df = self.df + df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) - + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + cq = df.write.option('checkpointLocation', chk)\ + .queryName("this_query").startStream(path=out, format='parquet') + self.assertEqual(cq.name, "this_query") + self.assertTrue(cq.isActive) + cq.stop() shutil.rmtree(tmpPath) def test_help_command(self): From c3464c0c684d4a76d5a0161192e01405b5518e03 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 18 Apr 2016 11:04:00 -0700 Subject: [PATCH 15/24] Update streaming.py --- python/pyspark/sql/streaming.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index d4e2370246736..dd6c36542dc5e 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -19,7 +19,7 @@ from pyspark import since -__all__ = ["ContinuousQuery", "ProcessingTime"] +__all__ = ["ContinuousQuery"] class ContinuousQuery(object): From a552aef1fb122b54a38d42ad55b4a6b33dd48a5a Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 18 Apr 2016 11:20:02 -0700 Subject: [PATCH 16/24] add keyword test --- python/pyspark/sql/readwriter.py | 1 + python/pyspark/sql/tests.py | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index fb879ea4165da..c5cac30503a55 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -23,6 +23,7 @@ from py4j.java_gateway import JavaClass from pyspark import RDD, since +from pyspark.ml.util import keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index bdb494a7054df..9a266fd961eae 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -879,6 +879,15 @@ def test_save_and_load_builder(self): shutil.rmtree(tmpPath) + def test_stream_trigger_takes_keyword_args(self): + df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + try: + df.write.trigger('5 seconds') + self.fail("Should have thrown an exception") + except e as TypeError: + # should throw error + pass + def test_stream_save_options(self): df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') tmpPath = tempfile.mkdtemp() From 538f4104d3c9153fc9b562c35cbe6064b07d0b2d Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 18 Apr 2016 11:45:33 -0700 Subject: [PATCH 17/24] move keyword_args --- python/pyspark/__init__.py | 14 ++++++++++++++ python/pyspark/ml/classification.py | 2 +- python/pyspark/ml/clustering.py | 2 +- python/pyspark/ml/evaluation.py | 3 +-- python/pyspark/ml/feature.py | 4 ++-- python/pyspark/ml/pipeline.py | 5 ++--- python/pyspark/ml/recommendation.py | 2 +- python/pyspark/ml/regression.py | 2 +- python/pyspark/ml/tests.py | 2 +- python/pyspark/ml/tuning.py 
| 4 ++-- python/pyspark/ml/util.py | 14 -------------- python/pyspark/sql/readwriter.py | 3 +-- 12 files changed, 27 insertions(+), 30 deletions(-) diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 111ebaafee3e1..5f811cbd7f2a3 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -84,6 +84,20 @@ def copy_func(f, name=None, sinceversion=None, doc=None): return fn +def keyword_only(func): + """ + A decorator that forces keyword arguments in the wrapped method + and saves actual input keyword arguments in `_input_kwargs`. + """ + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) > 1: + raise TypeError("Method %s forces keyword arguments." % func.__name__) + wrapper._input_kwargs = kwargs + return func(*args, **kwargs) + return wrapper + + # for back compatibility from pyspark.sql import SQLContext, HiveContext, Row diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index e64c7a392b93b..69115c77530f4 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -17,7 +17,7 @@ import warnings -from pyspark import since +from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable from pyspark.ml.param import TypeConverters diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index f071c597c87f3..c997102685210 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,7 +15,7 @@ # limitations under the License. # -from pyspark import since +from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index c9b95b3bf45d9..a9dd7a45fe6af 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -17,11 +17,10 @@ from abc import abstractmethod, ABCMeta -from pyspark import since +from pyspark import since, keyword_only from pyspark.ml.wrapper import JavaWrapper from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol -from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc __all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 86b53285b5b00..2df68600c9b4f 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -19,10 +19,10 @@ if sys.version > '3': basestring = str -from pyspark import since +from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * -from pyspark.ml.util import keyword_only, JavaMLReadable, JavaMLWritable +from pyspark.ml.util import JavaMLReadable, JavaMLWritable from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 2b5504bc2966a..28212b5af15b4 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -20,11 +20,10 @@ if sys.version > '3': basestring = str -from pyspark import SparkContext -from pyspark import since +from pyspark import since, keyword_only, SparkContext from pyspark.ml import Estimator, 
Model, Transformer from pyspark.ml.param import Param, Params -from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable +from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.common import inherit_doc diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 7c7a1b67a100e..8d9f008e9293f 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,7 +15,7 @@ # limitations under the License. # -from pyspark import since +from pyspark import since, keyword_only from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 1c18df3b27ab9..eb7c54d9e3a62 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -17,7 +17,7 @@ import warnings -from pyspark import since +from pyspark import since, keyword_only from pyspark.ml.param.shared import * from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2dcd5eeb52c21..1faafb47311bb 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -41,6 +41,7 @@ import tempfile import numpy as np +from pyspark import keyword_only from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier from pyspark.ml.clustering import KMeans @@ -50,7 +51,6 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor from pyspark.ml.tuning import * -from pyspark.ml.util import keyword_only from pyspark.ml.util import MLWritable, MLWriter from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index ea8c61b7efe6c..8eeb5079778be 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -19,11 +19,11 @@ import numpy as np from pyspark import SparkContext -from pyspark import since +from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed -from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable +from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable from pyspark.ml.wrapper import JavaWrapper from pyspark.sql.functions import rand from pyspark.mllib.common import inherit_doc, _py2java diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index d4411fdfb9dde..22877816c32c2 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -38,20 +38,6 @@ def _jvm(): raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") -def keyword_only(func): - """ - A decorator that forces keyword arguments in the wrapped method - and saves actual input keyword arguments in `_input_kwargs`. - """ - @wraps(func) - def wrapper(*args, **kwargs): - if len(args) > 1: - raise TypeError("Method %s forces keyword arguments." 
% func.__name__) - wrapper._input_kwargs = kwargs - return func(*args, **kwargs) - return wrapper - - class Identifiable(object): """ Object with a unique ID. diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index c5cac30503a55..f7383207bf9c8 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -22,8 +22,7 @@ from py4j.java_gateway import JavaClass -from pyspark import RDD, since -from pyspark.ml.util import keyword_only +from pyspark import RDD, since, keyword_only from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import _to_seq from pyspark.sql.types import * From 0e0b10b580136eb7522fab38639a4aec1a8bcecc Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 18 Apr 2016 11:59:53 -0700 Subject: [PATCH 18/24] import wraps --- python/pyspark/__init__.py | 1 + python/pyspark/ml/util.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 5f811cbd7f2a3..ec1687415a7f6 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -37,6 +37,7 @@ """ +from functools import wraps import types from pyspark.conf import SparkConf diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index eb6b84843c5d4..7003e587ad0eb 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -17,7 +17,6 @@ import sys import uuid -from functools import wraps if sys.version > '3': basestring = str From fbe93c9c99e5469ebf5f05379d8154c6e48a4cdd Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 18 Apr 2016 13:31:18 -0700 Subject: [PATCH 19/24] Update tests.py --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9a266fd961eae..9514e21efd8ba 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -884,7 +884,7 @@ def test_stream_trigger_takes_keyword_args(self): try: df.write.trigger('5 seconds') self.fail("Should have thrown an exception") - except e as TypeError: + except TypeError: # should throw error pass From 302da9bedb9eabe42ace6ee9b747ee89b5fed8fe Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 18 Apr 2016 14:53:16 -0700 Subject: [PATCH 20/24] Update readwriter.py --- python/pyspark/sql/readwriter.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index f7383207bf9c8..6c809d1139b2d 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -515,16 +515,18 @@ def startStream(self, path=None, format=None, partitionBy=None, queryName=None, :param options: All other string options. You may want to provide a `checkpointLocation` for most streams, however it is not required for a `memory` stream. - >>> cq = sdf.write.format('memory').startStream() + >>> cq = sdf.write.format('memory').queryName('this_query').startStream() >>> cq.isActive True + >>> cq.name + u'this_query' >>> cq.stop() >>> cq.isActive False >>> cq = sdf.write.trigger(processingTime='5 seconds').startStream( - ... queryName='my_query', format='memory') + ... 
queryName='that_query', format='memory') >>> cq.name - u'my_query' + u'that_query' >>> cq.isActive True >>> cq.stop() From b78411430214f5e8bdf1c43f5dabb6d5be1d5a2e Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 19 Apr 2016 11:32:07 -0700 Subject: [PATCH 21/24] address comments --- python/pyspark/sql/tests.py | 62 +++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9a266fd961eae..e0805d7cf5bd1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -888,6 +888,22 @@ def test_stream_trigger_takes_keyword_args(self): # should throw error pass + def test_stream_read_options(self): + schema = StructType([StructField("data", StringType(), False)]) + df = self.sqlCtx.read.format('text').option('path', 'python/test_support/sql/streaming')\ + .schema(schema).stream() + self.assertTrue(df.isStreaming) + self.assertEqual(df.schema.simpleString(), "struct") + + def test_stream_read_options_overwrite(self): + bad_schema = StructType([StructField("test", IntegerType(), False)]) + schema = StructType([StructField("data", StringType(), False)]) + df = self.sqlCtx.read.format('csv').option('path', 'python/test_support/sql/fake') \ + .schema(bad_schema).stream(path='python/test_support/sql/streaming', + schema=schema, format='text') + self.assertTrue(df.isStreaming) + self.assertEqual(df.schema.simpleString(), "struct") + def test_stream_save_options(self): df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') tmpPath = tempfile.mkdtemp() @@ -895,13 +911,53 @@ def test_stream_save_options(self): self.assertTrue(df.isStreaming) out = os.path.join(tmpPath, 'out') chk = os.path.join(tmpPath, 'chk') - cq = df.write.option('checkpointLocation', chk)\ - .queryName("this_query").startStream(path=out, format='parquet') - self.assertEqual(cq.name, "this_query") + cq = df.write.option('checkpointLocation', chk).queryName('this_query')\ + .format('parquet').option('path', out).startStream() + self.assertEqual(cq.name, 'this_query') + self.assertTrue(cq.isActive) + cq.stop() + shutil.rmtree(tmpPath) + + def test_stream_save_options_overwrite(self): + df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + fake1 = os.path.join(tmpPath, 'fake1') + fake2 = os.path.join(tmpPath, 'fake2') + cq = df.write.option('checkpointLocation', fake1).format('memory').option('path', fake2) \ + .queryName('fake_query').startStream(path=out, format='parquet', queryName='this_query', + checkpointLocation=chk) + self.assertEqual(cq.name, 'this_query') self.assertTrue(cq.isActive) cq.stop() shutil.rmtree(tmpPath) + def test_stream_await_termination(self): + df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + self.assertTrue(df.isStreaming) + out = os.path.join(tmpPath, 'out') + chk = os.path.join(tmpPath, 'chk') + cq = df.write.startStream(path=out, format='parquet', queryName='this_query', + checkpointLocation=chk) + self.assertTrue(cq.isActive) + try: + cq.awaitTermination("hello") + self.fail("Expected a value exception") + except ValueError: + pass + now = time.time() + res = cq.awaitTermination(2600) # test should take at least 2 seconds + duration = time.time() - now + self.assertTrue(duration >= 
2) + self.assertFalse(res) + cq.stop() + shutil.rmtree(tmpPath) + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) From c07d7959655c94aaf40f6623811d6364372ac1ef Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 19 Apr 2016 13:45:46 -0700 Subject: [PATCH 22/24] add process all available --- python/pyspark/sql/streaming.py | 10 ++++++++++ python/pyspark/sql/tests.py | 10 +++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index dd6c36542dc5e..549561669fdad 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -70,6 +70,16 @@ def awaitTermination(self, timeoutMs=None): else: return self._jcq.awaitTermination() + @since(2.0) + def processAllAvailable(self): + """Blocks until all available data in the source has been processed an committed to the + sink. This method is intended for testing. Note that in the case of continually arriving + data, this method may block forever. Additionally, this method is only guaranteed to block + until data that has been synchronously appended data to a stream source prior to invocation. + (i.e. `getOffset` must immediately reflect the addition). + """ + return self._jcq.processAllAvailable() + @since(2.0) def stop(self): """Stop this continuous query. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a441ae425b66e..2bdbdb4edcd19 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -915,6 +915,9 @@ def test_stream_save_options(self): .format('parquet').option('path', out).startStream() self.assertEqual(cq.name, 'this_query') self.assertTrue(cq.isActive) + cq.processAllAvailable() + self.assertTrue(len([f for f in os.listdir(out) if 'parquet' in f]) > 0) + self.assertTrue(len(os.listdir(chk)) > 0) cq.stop() shutil.rmtree(tmpPath) @@ -932,6 +935,11 @@ def test_stream_save_options_overwrite(self): checkpointLocation=chk) self.assertEqual(cq.name, 'this_query') self.assertTrue(cq.isActive) + cq.processAllAvailable() + self.assertTrue(len([f for f in os.listdir(out) if 'parquet' in f]) > 0) + self.assertTrue(len(os.listdir(chk)) > 0) + self.assertTrue(len(os.listdir(fake1)) == 0) + self.assertTrue(len(os.listdir(fake2)) == 0) cq.stop() shutil.rmtree(tmpPath) @@ -951,7 +959,7 @@ def test_stream_await_termination(self): except ValueError: pass now = time.time() - res = cq.awaitTermination(2600) # test should take at least 2 seconds + res = cq.awaitTermination(2600) # test should take at least 2 seconds duration = time.time() - now self.assertTrue(duration >= 2) self.assertFalse(res) From 981f8e1ec45155f6d4b645fb136815d8e469ecb2 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 19 Apr 2016 15:05:46 -0700 Subject: [PATCH 23/24] fix test --- python/pyspark/sql/tests.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2bdbdb4edcd19..1343082cfe585 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -916,7 +916,10 @@ def test_stream_save_options(self): self.assertEqual(cq.name, 'this_query') self.assertTrue(cq.isActive) cq.processAllAvailable() - self.assertTrue(len([f for f in os.listdir(out) if 'parquet' in f]) > 0) + output_files = [] + for _, _, files in os.walk(out): + output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')]) + self.assertTrue(len(output_files) > 0) 
self.assertTrue(len(os.listdir(chk)) > 0) cq.stop() shutil.rmtree(tmpPath) @@ -936,7 +939,10 @@ def test_stream_save_options_overwrite(self): self.assertEqual(cq.name, 'this_query') self.assertTrue(cq.isActive) cq.processAllAvailable() - self.assertTrue(len([f for f in os.listdir(out) if 'parquet' in f]) > 0) + output_files = [] + for _, _, files in os.walk(out): + output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')]) + self.assertTrue(len(output_files) > 0) self.assertTrue(len(os.listdir(chk)) > 0) self.assertTrue(len(os.listdir(fake1)) == 0) self.assertTrue(len(os.listdir(fake2)) == 0) From 3d36543392d43077194703be66593767a786e433 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 19 Apr 2016 16:29:12 -0700 Subject: [PATCH 24/24] Update tests.py --- python/pyspark/sql/tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1343082cfe585..e4b7cbca86179 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -944,8 +944,8 @@ def test_stream_save_options_overwrite(self): output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')]) self.assertTrue(len(output_files) > 0) self.assertTrue(len(os.listdir(chk)) > 0) - self.assertTrue(len(os.listdir(fake1)) == 0) - self.assertTrue(len(os.listdir(fake2)) == 0) + self.assertFalse(os.path.isdir(fake1)) # should not have been created + self.assertFalse(os.path.isdir(fake2)) # should not have been created cq.stop() shutil.rmtree(tmpPath)
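Taken together, the series lands on the user-facing flow below. This is only a sketch stitched from the final doctests and tests; `sqlContext` is assumed to be an existing SQLContext, and the output and checkpoint paths are hypothetical temporary directories.

import os
import tempfile

sdf = sqlContext.read.format('text').stream('python/test_support/sql/streaming')
assert sdf.isStreaming

temp = tempfile.mkdtemp()
cq = sdf.write \
    .trigger(processingTime='5 seconds') \
    .option('checkpointLocation', os.path.join(temp, 'chk')) \
    .queryName('demo_query') \
    .startStream(path=os.path.join(temp, 'out'), format='parquet')

cq.processAllAvailable()              # block until currently available data is written
assert not cq.awaitTermination(1000)  # still active after waiting one second
cq.stop()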